Fixed the actual root cause instead of just masking it in tests:

1. Root problem:
   - The C++-side ZmqDistanceComputer creates ZMQ connections but never cleans them up.
   - Python 3.9/3.13 are more sensitive to cleanup timing during interpreter shutdown.
2. Core fixes in SearcherBase and LeannSearcher:
   - Added a cleanup() method to BaseSearcher that shuts down the ZMQ connection and the embedding server.
   - LeannSearcher.cleanup() now also handles ZMQ context cleanup.
   - Both the HNSW and DiskANN searchers now properly delete their C++ index objects.
3. Backend-specific cleanup:
   - HNSWSearcher.cleanup(): deletes self.index to trigger the C++ destructors.
   - DiskannSearcher.cleanup(): deletes self._index and resets state.
   - Both force garbage collection after deletion.
4. Test infrastructure:
   - Added an auto_cleanup_searcher fixture for explicit resource management (a sketch follows below).
   - Global cleanup is now more aggressive about destroying the ZMQ context.

This is the proper fix: cleaning up resources at the source rather than working around the issue in tests. The hanging was caused by C++-side ZMQ connections not being properly terminated when is_recompute=True.
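The test-side piece is not shown in this change, so as a rough illustration only: a minimal pytest fixture along these lines could provide the explicit resource management described in point 4. Everything except the `cleanup()` method (which this change adds to the searchers) is an assumption about how the fixture is wired.

```python
# Hypothetical sketch of the auto_cleanup_searcher fixture described above.
# Assumes a pytest-based suite and that searchers expose the cleanup() method
# added in this change; the factory wiring is illustrative, not the real fixture.
import gc

import pytest


@pytest.fixture
def auto_cleanup_searcher():
    """Yield a factory that tracks created searchers and cleans them up after the test."""
    searchers = []

    def _make(searcher_factory, *args, **kwargs):
        searcher = searcher_factory(*args, **kwargs)
        searchers.append(searcher)
        return searcher

    yield _make

    # Explicitly release ZMQ sockets and C++ index objects at the source,
    # then force a collection so destructors run before interpreter shutdown.
    for searcher in searchers:
        try:
            searcher.cleanup()
        except Exception:
            pass
    gc.collect()
```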
Python · 484 lines · 19 KiB
import contextlib
import logging
import os
import struct
import sys
from pathlib import Path
from typing import Any, Literal, Optional

import numpy as np
import psutil
from leann.interface import (
    LeannBackendBuilderInterface,
    LeannBackendFactoryInterface,
    LeannBackendSearcherInterface,
)
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def suppress_cpp_output_if_needed():
    """Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
    log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()

    # Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
    should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]

    if not should_suppress:
        # Don't suppress, just yield
        yield
        return

    # Save original file descriptors
    stdout_fd = sys.stdout.fileno()
    stderr_fd = sys.stderr.fileno()

    # Save original stdout/stderr
    stdout_dup = os.dup(stdout_fd)
    stderr_dup = os.dup(stderr_fd)

    try:
        # Redirect to /dev/null
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, stdout_fd)
        os.dup2(devnull, stderr_fd)
        os.close(devnull)

        yield

    finally:
        # Restore original file descriptors
        os.dup2(stdout_dup, stdout_fd)
        os.dup2(stderr_dup, stderr_fd)
        os.close(stdout_dup)
        os.close(stderr_dup)


def _get_diskann_metrics():
    from . import _diskannpy as diskannpy  # type: ignore

    return {
        "mips": diskannpy.Metric.INNER_PRODUCT,
        "l2": diskannpy.Metric.L2,
        "cosine": diskannpy.Metric.COSINE,
    }


@contextlib.contextmanager
def chdir(path):
    original_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_dir)


def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
    num_vectors, dim = data.shape
    with open(file_path, "wb") as f:
        f.write(struct.pack("I", num_vectors))
        f.write(struct.pack("I", dim))
        f.write(data.tobytes())


def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
    """
    Calculate smart memory configuration for DiskANN based on data size and system specs.

    Args:
        data: The embedding data array

    Returns:
        tuple: (search_memory_maximum, build_memory_maximum) in GB
    """
    num_vectors, dim = data.shape

    # Calculate embedding storage size
    embedding_size_bytes = num_vectors * dim * 4  # float32 = 4 bytes
    embedding_size_gb = embedding_size_bytes / (1024**3)

    # search_memory_maximum: 1/10 of embedding size for optimal PQ compression
    # This controls Product Quantization size - smaller means more compression
    search_memory_gb = max(0.1, embedding_size_gb / 10)  # At least 100MB

    # build_memory_maximum: Based on available system RAM for sharding control
    # This controls how much memory DiskANN uses during index construction
    available_memory_gb = psutil.virtual_memory().available / (1024**3)
    total_memory_gb = psutil.virtual_memory().total / (1024**3)

    # Use 50% of available memory, but at least 2GB and at most 75% of total
    build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))

    logger.info(
        f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
        f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
        f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
    )

    return search_memory_gb, build_memory_gb


@register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface):
    @staticmethod
    def builder(**kwargs) -> LeannBackendBuilderInterface:
        return DiskannBuilder(**kwargs)

    @staticmethod
    def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
        return DiskannSearcher(index_path, **kwargs)


class DiskannBuilder(LeannBackendBuilderInterface):
    def __init__(self, **kwargs):
        self.build_params = kwargs

    def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
        """
        Safely cleanup files after partition.
        In partition mode, C++ doesn't read _disk.index content,
        so we can delete it if all derived files exist.
        """
        disk_index_file = index_dir / f"{index_prefix}_disk.index"
        beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"

        # Required files that C++ partition mode needs
        # Note: C++ generates these with _disk.index suffix
        disk_suffix = "_disk.index"
        required_files = [
            f"{index_prefix}{disk_suffix}_medoids.bin",  # Critical: assert fails if missing
            # Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
            f"{index_prefix}_pq_pivots.bin",  # PQ table
            f"{index_prefix}_pq_compressed.bin",  # PQ compressed vectors
        ]

        # Check if all required files exist
        missing_files = []
        for filename in required_files:
            file_path = index_dir / filename
            if not file_path.exists():
                missing_files.append(filename)

        if missing_files:
            logger.warning(
                f"Cannot safely delete _disk.index - missing required files: {missing_files}"
            )
            logger.info("Keeping all original files for safety")
            return

        # Calculate space savings
        space_saved = 0
        files_to_delete = []

        if disk_index_file.exists():
            space_saved += disk_index_file.stat().st_size
            files_to_delete.append(disk_index_file)

        if beam_search_file.exists():
            space_saved += beam_search_file.stat().st_size
            files_to_delete.append(beam_search_file)

        # Safe to delete!
        for file_to_delete in files_to_delete:
            try:
                os.remove(file_to_delete)
                logger.info(f"✅ Safely deleted: {file_to_delete.name}")
            except Exception as e:
                logger.warning(f"Failed to delete {file_to_delete.name}: {e}")

        if space_saved > 0:
            space_saved_mb = space_saved / (1024 * 1024)
            logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")

        # Show what files are kept
        logger.info("📁 Kept essential files for partition mode:")
        for filename in required_files:
            file_path = index_dir / filename
            if file_path.exists():
                size_mb = file_path.stat().st_size / (1024 * 1024)
                logger.info(f" - {filename} ({size_mb:.1f} MB)")

    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem
        index_dir.mkdir(parents=True, exist_ok=True)

        if data.dtype != np.float32:
            logger.warning(f"Converting data to float32, shape: {data.shape}")
            data = data.astype(np.float32)

        data_filename = f"{index_prefix}_data.bin"
        _write_vectors_to_bin(data, index_dir / data_filename)

        build_kwargs = {**self.build_params, **kwargs}

        # Extract is_recompute from nested backend_kwargs if needed
        is_recompute = build_kwargs.get("is_recompute", False)
        if not is_recompute and "backend_kwargs" in build_kwargs:
            is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)

        # Flatten all backend_kwargs parameters to top level for compatibility
        if "backend_kwargs" in build_kwargs:
            nested_params = build_kwargs.pop("backend_kwargs")
            build_kwargs.update(nested_params)

        metric_enum = _get_diskann_metrics().get(
            build_kwargs.get("distance_metric", "mips").lower()
        )
        if metric_enum is None:
            raise ValueError(
                f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
            )

        # Calculate smart memory configuration if not explicitly provided
        if (
            "search_memory_maximum" not in build_kwargs
            or "build_memory_maximum" not in build_kwargs
        ):
            smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
        else:
            smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
            smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)

        try:
            from . import _diskannpy as diskannpy  # type: ignore

            with chdir(index_dir):
                diskannpy.build_disk_float_index(
                    metric_enum,
                    data_filename,
                    index_prefix,
                    build_kwargs.get("complexity", 64),
                    build_kwargs.get("graph_degree", 32),
                    build_kwargs.get("search_memory_maximum", smart_search_mem),
                    build_kwargs.get("build_memory_maximum", smart_build_mem),
                    build_kwargs.get("num_threads", 8),
                    build_kwargs.get("pq_disk_bytes", 0),
                    "",
                )

            # Auto-partition if is_recompute is enabled
            if build_kwargs.get("is_recompute", False):
                logger.info("is_recompute=True, starting automatic graph partitioning...")
                from .graph_partition import partition_graph

                # Partition the index using absolute paths
                # Convert to absolute paths to avoid issues with working directory changes
                absolute_index_dir = Path(index_dir).resolve()
                absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
                disk_graph_path, partition_bin_path = partition_graph(
                    index_prefix_path=absolute_index_prefix_path,
                    output_dir=str(absolute_index_dir),
                    partition_prefix=index_prefix,
                )

                # Safe cleanup: In partition mode, C++ doesn't read _disk.index content
                # but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
                self._safe_cleanup_after_partition(index_dir, index_prefix)

                logger.info("✅ Graph partitioning completed successfully!")
                logger.info(f" - Disk graph: {disk_graph_path}")
                logger.info(f" - Partition file: {partition_bin_path}")

        finally:
            temp_data_file = index_dir / data_filename
            if temp_data_file.exists():
                os.remove(temp_data_file)
                logger.debug(f"Cleaned up temporary data file: {temp_data_file}")


class DiskannSearcher(BaseSearcher):
    def __init__(self, index_path: str, **kwargs):
        super().__init__(
            index_path,
            backend_module_name="leann_backend_diskann.diskann_embedding_server",
            **kwargs,
        )

        # Initialize DiskANN index with suppressed C++ output based on log level
        with suppress_cpp_output_if_needed():
            from . import _diskannpy as diskannpy  # type: ignore

            distance_metric = kwargs.get("distance_metric", "mips").lower()
            metric_enum = _get_diskann_metrics().get(distance_metric)
            if metric_enum is None:
                raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")

            self.num_threads = kwargs.get("num_threads", 8)

            # For DiskANN, we need to reinitialize the index when zmq_port changes
            # Store the initialization parameters for later use
            # Note: C++ load method expects the BASE path (without _disk.index suffix)
            # C++ internally constructs: index_prefix + "_disk.index"
            index_name = self.index_path.stem  # "simple_test.leann" -> "simple_test"
            diskann_index_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
            full_index_prefix = diskann_index_prefix  # /path/to/simple_test (base path)

            # Auto-detect partition files and set partition_prefix
            partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
            partition_bin_file = self.index_dir / f"{index_name}_partition.bin"

            partition_prefix = ""
            if partition_graph_file.exists() and partition_bin_file.exists():
                # C++ expects full path prefix, not just filename
                partition_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
                logger.info(
                    f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
                )
            else:
                logger.debug("No partition files detected, using standard index files")

            self._init_params = {
                "metric_enum": metric_enum,
                "full_index_prefix": full_index_prefix,
                "num_threads": self.num_threads,
                "num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
                "cache_mechanism": 1,
                "pq_prefix": "",
                "partition_prefix": partition_prefix,
            }

            # Log partition configuration for debugging
            if partition_prefix:
                logger.info(
                    f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
                )
            self._diskannpy = diskannpy
            self._current_zmq_port = None
            self._index = None
            logger.debug("DiskANN searcher initialized (index will be loaded on first search)")

    def _ensure_index_loaded(self, zmq_port: int):
        """Ensure the index is loaded with the correct zmq_port."""
        if self._index is None or self._current_zmq_port != zmq_port:
            # Need to (re)load the index with the correct zmq_port
            with suppress_cpp_output_if_needed():
                if self._index is not None:
                    logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
                else:
                    logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")

                self._index = self._diskannpy.StaticDiskFloatIndex(
                    self._init_params["metric_enum"],
                    self._init_params["full_index_prefix"],
                    self._init_params["num_threads"],
                    self._init_params["num_nodes_to_cache"],
                    self._init_params["cache_mechanism"],
                    zmq_port,
                    self._init_params["pq_prefix"],
                    self._init_params["partition_prefix"],
                )
                self._current_zmq_port = zmq_port

    def search(
        self,
        query: np.ndarray,
        top_k: int,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = False,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        zmq_port: Optional[int] = None,
        batch_recompute: bool = False,
        dedup_node_dis: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Search for nearest neighbors using DiskANN index.

        Args:
            query: Query vectors (B, D) where B is batch size, D is dimension
            top_k: Number of nearest neighbors to return
            complexity: Search complexity/candidate list size, higher = more accurate but slower
            beam_width: Number of parallel IO requests per iteration
            prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
            recompute_embeddings: Whether to fetch fresh embeddings from server
            pruning_strategy: PQ candidate selection strategy:
                - "global": Use global pruning strategy (default)
                - "local": Use local pruning strategy
                - "proportional": Not supported in DiskANN (raises NotImplementedError)
            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
            batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
            dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
            **kwargs: Additional DiskANN-specific parameters (for legacy compatibility)

        Returns:
            Dict with 'labels' (list of lists) and 'distances' (ndarray)
        """
        # Handle zmq_port compatibility: Ensure index is loaded with correct port
        if recompute_embeddings:
            if zmq_port is None:
                raise ValueError("zmq_port must be provided if recompute_embeddings is True")
            self._ensure_index_loaded(zmq_port)
        else:
            # If not recomputing, we still need an index, use a default port
            if self._index is None:
                self._ensure_index_loaded(6666)  # Default port when not recomputing

        # DiskANN doesn't support "proportional" strategy
        if pruning_strategy == "proportional":
            raise NotImplementedError(
                "DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
            )

        if query.dtype != np.float32:
            query = query.astype(np.float32)

        # Map pruning_strategy to DiskANN's global_pruning parameter
        if pruning_strategy == "local":
            use_global_pruning = False
        else:  # "global"
            use_global_pruning = True

        # Perform search with suppressed C++ output based on log level
        use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
        recompute_neighbors = False
        with suppress_cpp_output_if_needed():
            labels, distances = self._index.batch_search(
                query,
                query.shape[0],
                top_k,
                complexity,
                beam_width,
                self.num_threads,
                use_deferred_fetch,
                kwargs.get("skip_search_reorder", False),
                recompute_neighbors,
                dedup_node_dis,
                prune_ratio,
                batch_recompute,
                use_global_pruning,
            )

        string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]

        return {"labels": string_labels, "distances": distances}

    def cleanup(self):
        """Cleanup DiskANN-specific resources including C++ index."""
        # Call parent cleanup first
        super().cleanup()

        # Delete the C++ index to trigger destructors
        try:
            if hasattr(self, "_index") and self._index is not None:
                del self._index
                self._index = None
                self._current_zmq_port = None
        except Exception:
            pass

        # Force garbage collection to ensure C++ objects are destroyed
        try:
            import gc

            gc.collect()
        except Exception:
            pass
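For reference, a minimal usage sketch of the lifecycle this cleanup path is meant to cover: construct a DiskANN searcher, query it with recompute enabled, and release the C++/ZMQ resources explicitly instead of relying on interpreter shutdown. The index path, module import path, port, and query dimension below are placeholders, not values from this change.

```python
# Hypothetical usage sketch; only DiskannSearcher, search(), and cleanup() come from the file above.
import numpy as np

from leann_backend_diskann.diskann_backend import DiskannSearcher  # module path is an assumption

searcher = DiskannSearcher("indexes/simple_test.leann", distance_metric="mips")
try:
    query = np.random.rand(1, 768).astype(np.float32)  # dimension is illustrative
    results = searcher.search(
        query,
        top_k=10,
        recompute_embeddings=True,
        zmq_port=6666,  # must match the running embedding server
    )
    print(results["labels"][0])
finally:
    searcher.cleanup()  # deletes the C++ index and forces GC so destructors run promptly
```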