feat: make diskann runnable

This commit is contained in:
Andy Lee
2025-07-22 14:26:03 -07:00
parent 71e5f1774c
commit 8513471573
9 changed files with 394 additions and 760 deletions

View File

@@ -142,12 +142,12 @@ class HNSWSearcher(BaseSearcher):
self,
query: np.ndarray,
top_k: int,
zmq_port: Optional[int] = None,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
@@ -165,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
expected_zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -177,6 +177,11 @@ class HNSWSearcher(BaseSearcher):
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings:
if zmq_port is None:
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -184,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW()
params.zmq_port = expected_zmq_port
if zmq_port is not None:
params.zmq_port = (
zmq_port # C++ code won't use this if recompute_embeddings is False
)
params.efSearch = complexity
params.beam_size = beam_width