import contextlib
import logging
import os
import struct
import sys
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple

import numpy as np
import psutil
from leann.interface import (
    LeannBackendBuilderInterface,
    LeannBackendFactoryInterface,
    LeannBackendSearcherInterface,
)
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def suppress_cpp_output_if_needed():
    """Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
    log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()

    # Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
    should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]

    if not should_suppress:
        # Don't suppress, just yield
        yield
        return

    # Save original file descriptors
    stdout_fd = sys.stdout.fileno()
    stderr_fd = sys.stderr.fileno()

    # Save original stdout/stderr
    stdout_dup = os.dup(stdout_fd)
    stderr_dup = os.dup(stderr_fd)

    try:
        # Redirect to /dev/null
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, stdout_fd)
        os.dup2(devnull, stderr_fd)
        os.close(devnull)
        yield
    finally:
        # Restore original file descriptors
        os.dup2(stdout_dup, stdout_fd)
        os.dup2(stderr_dup, stderr_fd)
        os.close(stdout_dup)
        os.close(stderr_dup)


def _get_diskann_metrics():
    from . import _diskannpy as diskannpy  # type: ignore

    return {
        "mips": diskannpy.Metric.INNER_PRODUCT,
        "l2": diskannpy.Metric.L2,
        "cosine": diskannpy.Metric.COSINE,
    }


@contextlib.contextmanager
def chdir(path):
    original_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_dir)


def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
    num_vectors, dim = data.shape
    with open(file_path, "wb") as f:
        f.write(struct.pack("I", num_vectors))
        f.write(struct.pack("I", dim))
        f.write(data.tobytes())


def _calculate_smart_memory_config(data: np.ndarray) -> Tuple[float, float]:
    """
    Calculate smart memory configuration for DiskANN based on data size and system specs.
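
    For example, 1M float32 vectors at dimension 768 occupy roughly 2.9 GB, which yields
    a search memory budget of about 0.29 GB; the build budget is derived from system RAM below.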

    Args:
        data: The embedding data array

    Returns:
        tuple: (search_memory_maximum, build_memory_maximum) in GB
    """
    num_vectors, dim = data.shape

    # Calculate embedding storage size
    embedding_size_bytes = num_vectors * dim * 4  # float32 = 4 bytes
    embedding_size_gb = embedding_size_bytes / (1024**3)

    # search_memory_maximum: 1/10 of embedding size for optimal PQ compression
    # This controls Product Quantization size - smaller means more compression
    search_memory_gb = max(0.1, embedding_size_gb / 10)  # At least 100MB

    # build_memory_maximum: Based on available system RAM for sharding control
    # This controls how much memory DiskANN uses during index construction
    available_memory_gb = psutil.virtual_memory().available / (1024**3)
    total_memory_gb = psutil.virtual_memory().total / (1024**3)

    # Use 50% of available memory, but at least 2GB and at most 75% of total
    build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))

    logger.info(
        f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
        f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
        f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
    )

    return search_memory_gb, build_memory_gb


@register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface):
    @staticmethod
    def builder(**kwargs) -> LeannBackendBuilderInterface:
        return DiskannBuilder(**kwargs)

    @staticmethod
    def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
        return DiskannSearcher(index_path, **kwargs)


class DiskannBuilder(LeannBackendBuilderInterface):
    def __init__(self, **kwargs):
        self.build_params = kwargs

    def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
        """
        Safely clean up files after partitioning.

        In partition mode, C++ doesn't read _disk.index content,
        so we can delete it if all derived files exist.
        """
        disk_index_file = index_dir / f"{index_prefix}_disk.index"
        beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"

        # Required files that C++ partition mode needs
        # Note: C++ generates these with _disk.index suffix
        disk_suffix = "_disk.index"
        required_files = [
            f"{index_prefix}{disk_suffix}_medoids.bin",  # Critical: assert fails if missing
            # Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
            f"{index_prefix}_pq_pivots.bin",  # PQ table
            f"{index_prefix}_pq_compressed.bin",  # PQ compressed vectors
        ]

        # Check if all required files exist
        missing_files = []
        for filename in required_files:
            file_path = index_dir / filename
            if not file_path.exists():
                missing_files.append(filename)

        if missing_files:
            logger.warning(
                f"Cannot safely delete _disk.index - missing required files: {missing_files}"
            )
            logger.info("Keeping all original files for safety")
            return

        # Calculate space savings
        space_saved = 0
        files_to_delete = []

        if disk_index_file.exists():
            space_saved += disk_index_file.stat().st_size
            files_to_delete.append(disk_index_file)

        if beam_search_file.exists():
            space_saved += beam_search_file.stat().st_size
            files_to_delete.append(beam_search_file)

        # Safe to delete!
        for file_to_delete in files_to_delete:
            try:
                os.remove(file_to_delete)
                logger.info(f"✅ Safely deleted: {file_to_delete.name}")
            except Exception as e:
                logger.warning(f"Failed to delete {file_to_delete.name}: {e}")

        if space_saved > 0:
            space_saved_mb = space_saved / (1024 * 1024)
            logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")

        # Show what files are kept
        logger.info("📁 Kept essential files for partition mode:")
        for filename in required_files:
            file_path = index_dir / filename
            if file_path.exists():
                size_mb = file_path.stat().st_size / (1024 * 1024)
                logger.info(f"  - {filename} ({size_mb:.1f} MB)")

    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem
        index_dir.mkdir(parents=True, exist_ok=True)

        if data.dtype != np.float32:
            logger.warning(f"Converting data to float32, shape: {data.shape}")
            data = data.astype(np.float32)

        data_filename = f"{index_prefix}_data.bin"
        _write_vectors_to_bin(data, index_dir / data_filename)

        build_kwargs = {**self.build_params, **kwargs}

        # Extract is_recompute from nested backend_kwargs if needed
        is_recompute = build_kwargs.get("is_recompute", False)
        if not is_recompute and "backend_kwargs" in build_kwargs:
            is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)

        # Flatten all backend_kwargs parameters to top level for compatibility
        if "backend_kwargs" in build_kwargs:
            nested_params = build_kwargs.pop("backend_kwargs")
            build_kwargs.update(nested_params)

        metric_enum = _get_diskann_metrics().get(
            build_kwargs.get("distance_metric", "mips").lower()
        )
        if metric_enum is None:
            raise ValueError(
                f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
            )

        # Calculate smart memory configuration if not explicitly provided
        if (
            "search_memory_maximum" not in build_kwargs
            or "build_memory_maximum" not in build_kwargs
        ):
            smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
        else:
            smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
            smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)

        try:
            from . import _diskannpy as diskannpy  # type: ignore

            with chdir(index_dir):
                diskannpy.build_disk_float_index(
                    metric_enum,
                    data_filename,
                    index_prefix,
                    build_kwargs.get("complexity", 64),
                    build_kwargs.get("graph_degree", 32),
                    build_kwargs.get("search_memory_maximum", smart_search_mem),
                    build_kwargs.get("build_memory_maximum", smart_build_mem),
                    build_kwargs.get("num_threads", 8),
                    build_kwargs.get("pq_disk_bytes", 0),
                    "",
                )

            # Auto-partition if is_recompute is enabled
            if build_kwargs.get("is_recompute", False):
                logger.info("is_recompute=True, starting automatic graph partitioning...")
                from .graph_partition import partition_graph

                # Partition the index using absolute paths
                # Convert to absolute paths to avoid issues with working directory changes
                absolute_index_dir = Path(index_dir).resolve()
                absolute_index_prefix_path = str(absolute_index_dir / index_prefix)

                disk_graph_path, partition_bin_path = partition_graph(
                    index_prefix_path=absolute_index_prefix_path,
                    output_dir=str(absolute_index_dir),
                    partition_prefix=index_prefix,
                )

                # Safe cleanup: In partition mode, C++ doesn't read _disk.index content
                # but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
                self._safe_cleanup_after_partition(index_dir, index_prefix)

                logger.info("✅ Graph partitioning completed successfully!")
                logger.info(f"  - Disk graph: {disk_graph_path}")
                logger.info(f"  - Partition file: {partition_bin_path}")

        finally:
            temp_data_file = index_dir / data_filename
            if temp_data_file.exists():
                os.remove(temp_data_file)
                logger.debug(f"Cleaned up temporary data file: {temp_data_file}")


class DiskannSearcher(BaseSearcher):
    def __init__(self, index_path: str, **kwargs):
        super().__init__(
            index_path,
            backend_module_name="leann_backend_diskann.diskann_embedding_server",
            **kwargs,
        )

        # Initialize DiskANN index with suppressed C++ output based on log level
        with suppress_cpp_output_if_needed():
            from . import _diskannpy as diskannpy  # type: ignore

            distance_metric = kwargs.get("distance_metric", "mips").lower()
            metric_enum = _get_diskann_metrics().get(distance_metric)
            if metric_enum is None:
                raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")

            self.num_threads = kwargs.get("num_threads", 8)

            # For DiskANN, we need to reinitialize the index when zmq_port changes,
            # so store the initialization parameters for later use.
            # Note: the C++ load method expects the BASE path (without _disk.index suffix);
            # C++ internally constructs: index_prefix + "_disk.index"
            index_name = self.index_path.stem  # "simple_test.leann" -> "simple_test"
            diskann_index_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
            full_index_prefix = diskann_index_prefix  # /path/to/simple_test (base path)

            # Auto-detect partition files and set partition_prefix
            partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
            partition_bin_file = self.index_dir / f"{index_name}_partition.bin"

            partition_prefix = ""
            if partition_graph_file.exists() and partition_bin_file.exists():
                # C++ expects a full path prefix, not just the filename
                partition_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
                logger.info(
                    f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
                )
            else:
                logger.debug("No partition files detected, using standard index files")

            self._init_params = {
                "metric_enum": metric_enum,
                "full_index_prefix": full_index_prefix,
                "num_threads": self.num_threads,
                "num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
                "cache_mechanism": 1,
                "pq_prefix": "",
                "partition_prefix": partition_prefix,
            }

            self._diskannpy = diskannpy
            self._current_zmq_port = None
            self._index = None

        logger.debug("DiskANN searcher initialized (index will be loaded on first search)")

    def _ensure_index_loaded(self, zmq_port: int):
        """Ensure the index is loaded with the correct zmq_port."""
        if self._index is None or self._current_zmq_port != zmq_port:
            # Need to (re)load the index with the correct zmq_port
            with suppress_cpp_output_if_needed():
                if self._index is not None:
                    logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
                else:
                    logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")

                self._index = self._diskannpy.StaticDiskFloatIndex(
                    self._init_params["metric_enum"],
                    self._init_params["full_index_prefix"],
                    self._init_params["num_threads"],
                    self._init_params["num_nodes_to_cache"],
                    self._init_params["cache_mechanism"],
                    zmq_port,
                    self._init_params["pq_prefix"],
                    self._init_params["partition_prefix"],
                )
                self._current_zmq_port = zmq_port

    def search(
        self,
        query: np.ndarray,
        top_k: int,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = False,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        zmq_port: Optional[int] = None,
        batch_recompute: bool = False,
        dedup_node_dis: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Search for nearest neighbors using the DiskANN index.

        Args:
            query: Query vectors (B, D) where B is batch size, D is dimension
            top_k: Number of nearest neighbors to return
            complexity: Search complexity/candidate list size; higher = more accurate but slower
            beam_width: Number of parallel IO requests per iteration
            prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
            recompute_embeddings: Whether to fetch fresh embeddings from the server
            pruning_strategy: PQ candidate selection strategy:
                - "global": Use global pruning strategy (default)
                - "local": Use local pruning strategy
                - "proportional": Not supported by DiskANN; raises NotImplementedError
            zmq_port: ZMQ port for embedding server communication.
                Must be provided if recompute_embeddings is True.
            batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
            dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
            **kwargs: Additional DiskANN-specific parameters (for legacy compatibility)

        Returns:
            Dict with 'labels' (list of lists) and 'distances' (ndarray)
        """
        # Handle zmq_port compatibility: ensure the index is loaded with the correct port
        if recompute_embeddings:
            if zmq_port is None:
                raise ValueError("zmq_port must be provided if recompute_embeddings is True")
            self._ensure_index_loaded(zmq_port)
        else:
            # If not recomputing, we still need an index, so use a default port
            if self._index is None:
                self._ensure_index_loaded(6666)  # Default port when not recomputing

        # DiskANN doesn't support the "proportional" strategy
        if pruning_strategy == "proportional":
            raise NotImplementedError(
                "DiskANN backend does not support 'proportional' pruning strategy. "
                "Use 'global' or 'local' instead."
            )

        if query.dtype != np.float32:
            query = query.astype(np.float32)

        # Map pruning_strategy to DiskANN's global_pruning parameter
        if pruning_strategy == "local":
            use_global_pruning = False
        else:  # "global"
            use_global_pruning = True

        # Perform search with suppressed C++ output based on log level
        use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
        recompute_neighbors = False
        with suppress_cpp_output_if_needed():
            labels, distances = self._index.batch_search(
                query,
                query.shape[0],
                top_k,
                complexity,
                beam_width,
                self.num_threads,
                use_deferred_fetch,
                kwargs.get("skip_search_reorder", False),
                recompute_neighbors,
                dedup_node_dis,
                prune_ratio,
                batch_recompute,
                use_global_pruning,
            )

        string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]

        return {"labels": string_labels, "distances": distances}

    def cleanup(self):
        """Clean up DiskANN-specific resources, including the C++ index."""
        # Call parent cleanup first
        super().cleanup()

        # Delete the C++ index to trigger destructors
        try:
            if hasattr(self, "_index") and self._index is not None:
                del self._index
                self._index = None
                self._current_zmq_port = None
        except Exception:
            pass

        # Force garbage collection to ensure C++ objects are destroyed
        try:
            import gc

            gc.collect()
        except Exception:
            pass
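

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the backend API).
# It assumes the compiled `_diskannpy` extension is importable and builds a
# small on-disk index from random vectors under the hypothetical path
# "example_index/demo.leann". Searching with recompute_embeddings=True would
# additionally require a running embedding server, so it is not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    example_vectors = rng.random((1000, 128), dtype=np.float32)  # synthetic embeddings
    example_ids = [str(i) for i in range(len(example_vectors))]

    builder = DiskannBuilder(distance_metric="mips", complexity=64, graph_degree=32)
    builder.build(example_vectors, example_ids, "example_index/demo.leann")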