feat: different search_args and docstrings
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
||||
import json
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Literal
|
||||
import contextlib
|
||||
import pickle
|
||||
|
||||
@@ -108,24 +108,69 @@ class DiskannSearcher(BaseSearcher):
|
||||
kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
|
||||
)
|
||||
|
||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
||||
recompute = kwargs.get("recompute_beighbor_embeddings", False)
|
||||
if recompute:
|
||||
def search(self, query: np.ndarray, top_k: int,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Search for nearest neighbors using DiskANN index.
|
||||
|
||||
Args:
|
||||
query: Query vectors (B, D) where B is batch size, D is dimension
|
||||
top_k: Number of nearest neighbors to return
|
||||
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||
beam_width: Number of parallel IO requests per iteration
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server
|
||||
pruning_strategy: PQ candidate selection strategy:
|
||||
- "global": Use global pruning strategy (default)
|
||||
- "local": Use local pruning strategy
|
||||
- "proportional": Not supported in DiskANN, falls back to global
|
||||
zmq_port: ZMQ port for embedding server
|
||||
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
raise NotImplementedError("DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead.")
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
|
||||
zmq_port = kwargs.get("zmq_port", self.zmq_port)
|
||||
raise RuntimeError(f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}")
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
|
||||
# Map pruning_strategy to DiskANN's global_pruning parameter
|
||||
if pruning_strategy == "local":
|
||||
use_global_pruning = False
|
||||
else: # "global"
|
||||
use_global_pruning = True
|
||||
|
||||
labels, distances = self._index.batch_search(
|
||||
query, query.shape[0], top_k,
|
||||
kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False),
|
||||
recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0),
|
||||
kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False)
|
||||
complexity, beam_width, self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
use_recompute,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning
|
||||
)
|
||||
|
||||
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
Reference in New Issue
Block a user