feat: make diskann runnable

This commit is contained in:
Andy Lee
2025-07-22 14:26:03 -07:00
parent 71e5f1774c
commit 8513471573
9 changed files with 394 additions and 760 deletions

View File

@@ -1,10 +1,13 @@
import numpy as np
import os
import struct
import sys
from pathlib import Path
from typing import Dict, Any, List, Literal
from typing import Dict, Any, List, Literal, Optional
import contextlib
import logging
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import (
@@ -13,6 +16,46 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
if not should_suppress:
# Don't suppress, just yield
yield
return
# Save original file descriptors
stdout_fd = sys.stdout.fileno()
stderr_fd = sys.stderr.fileno()
# Save original stdout/stderr
stdout_dup = os.dup(stdout_fd)
stderr_dup = os.dup(stderr_fd)
try:
# Redirect to /dev/null
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, stdout_fd)
os.dup2(devnull, stderr_fd)
os.close(devnull)
yield
finally:
# Restore original file descriptors
os.dup2(stdout_dup, stdout_fd)
os.dup2(stderr_dup, stderr_fd)
os.close(stdout_dup)
os.close(stderr_dup)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy # type: ignore
@@ -64,6 +107,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
data_filename = f"{index_prefix}_data.bin"
@@ -74,7 +118,9 @@ class DiskannBuilder(LeannBackendBuilderInterface):
build_kwargs.get("distance_metric", "mips").lower()
)
if metric_enum is None:
raise ValueError("Unsupported distance_metric.")
raise ValueError(
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
)
try:
from . import _diskannpy as diskannpy # type: ignore
@@ -96,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
super().__init__(
index_path,
backend_module_name="leann_backend_diskann.embedding_server",
backend_module_name="leann_backend_diskann.diskann_embedding_server",
**kwargs,
)
from . import _diskannpy as diskannpy # type: ignore
distance_metric = kwargs.get("distance_metric", "mips").lower()
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
# Initialize DiskANN index with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
from . import _diskannpy as diskannpy # type: ignore
self.num_threads = kwargs.get("num_threads", 8)
self.zmq_port = kwargs.get("zmq_port", 6666)
distance_metric = kwargs.get("distance_metric", "mips").lower()
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
self.zmq_port,
"",
"",
)
self.num_threads = kwargs.get("num_threads", 8)
fake_zmq_port = 6666
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
fake_zmq_port, # Initial port, can be updated at runtime
"",
"",
)
def search(
self,
@@ -136,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
zmq_port: Optional[int] = None,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
@@ -155,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN, falls back to global
zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
@@ -163,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: DiskANN can now update port at runtime
if recompute_embeddings:
if zmq_port is None:
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
current_port = self._index.get_zmq_port()
if zmq_port != current_port:
logger.debug(
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
)
self._index.set_zmq_port(zmq_port)
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError(
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
)
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -188,21 +241,23 @@ class DiskannSearcher(BaseSearcher):
else: # "global"
use_global_pruning = True
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
use_recompute,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
# Perform search with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
recompute_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels