From 7b9406a3ead3f84fbe42df5555cec0a4293e557b Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Wed, 16 Jul 2025 15:25:58 -0700 Subject: [PATCH] feat: different search_args and docstrings --- .../leann_backend_diskann/diskann_backend.py | 65 +++++++-- .../leann_backend_hnsw/hnsw_backend.py | 130 ++++++++++++++---- .../src/leann/embedding_server_manager.py | 2 +- packages/leann-core/src/leann/interface.py | 49 ++++--- .../leann-core/src/leann/searcher_base.py | 50 +++++-- 5 files changed, 231 insertions(+), 65 deletions(-) diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py index e31f912..40e5495 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py @@ -3,7 +3,7 @@ import os import json import struct from pathlib import Path -from typing import Dict, Any, List +from typing import Dict, Any, List, Literal import contextlib import pickle @@ -108,24 +108,69 @@ class DiskannSearcher(BaseSearcher): kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", "" ) - def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: - recompute = kwargs.get("recompute_beighbor_embeddings", False) - if recompute: + def search(self, query: np.ndarray, top_k: int, + complexity: int = 64, + beam_width: int = 1, + prune_ratio: float = 0.0, + recompute_embeddings: bool = False, + pruning_strategy: Literal["global", "local", "proportional"] = "global", + zmq_port: int = 5557, + batch_recompute: bool = False, + dedup_node_dis: bool = False, + **kwargs) -> Dict[str, Any]: + """ + Search for nearest neighbors using DiskANN index. + + Args: + query: Query vectors (B, D) where B is batch size, D is dimension + top_k: Number of nearest neighbors to return + complexity: Search complexity/candidate list size, higher = more accurate but slower + beam_width: Number of parallel IO requests per iteration + prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0) + recompute_embeddings: Whether to fetch fresh embeddings from server + pruning_strategy: PQ candidate selection strategy: + - "global": Use global pruning strategy (default) + - "local": Use local pruning strategy + - "proportional": Not supported in DiskANN, falls back to global + zmq_port: ZMQ port for embedding server + batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific) + dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific) + **kwargs: Additional DiskANN-specific parameters (for legacy compatibility) + + Returns: + Dict with 'labels' (list of lists) and 'distances' (ndarray) + """ + # DiskANN doesn't support "proportional" strategy + if pruning_strategy == "proportional": + raise NotImplementedError("DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead.") + + # Use recompute_embeddings parameter + use_recompute = recompute_embeddings + if use_recompute: meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json" if not meta_file_path.exists(): - raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}") - zmq_port = kwargs.get("zmq_port", self.zmq_port) + raise RuntimeError(f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}") self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs) if query.dtype != np.float32: query = query.astype(np.float32) + # Map pruning_strategy to DiskANN's global_pruning parameter + if pruning_strategy == "local": + use_global_pruning = False + else: # "global" + use_global_pruning = True + labels, distances = self._index.batch_search( query, query.shape[0], top_k, - kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads, - kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False), - recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0), - kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False) + complexity, beam_width, self.num_threads, + kwargs.get("USE_DEFERRED_FETCH", False), + kwargs.get("skip_search_reorder", False), + use_recompute, + dedup_node_dis, + prune_ratio, + batch_recompute, + use_global_pruning ) string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels] diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 02d91c9..dccfa3f 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -2,7 +2,7 @@ import numpy as np import os import json from pathlib import Path -from typing import Dict, Any, List +from typing import Dict, Any, List, Literal import pickle import shutil @@ -13,17 +13,20 @@ from leann.registry import register_backend from leann.interface import ( LeannBackendFactoryInterface, LeannBackendBuilderInterface, - LeannBackendSearcherInterface + LeannBackendSearcherInterface, ) + def get_metric_map(): from . import faiss + return { "mips": faiss.METRIC_INNER_PRODUCT, "l2": faiss.METRIC_L2, "cosine": faiss.METRIC_INNER_PRODUCT, } + @register_backend("hnsw") class HNSWBackend(LeannBackendFactoryInterface): @staticmethod @@ -34,6 +37,7 @@ class HNSWBackend(LeannBackendFactoryInterface): def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: return HNSWSearcher(index_path, **kwargs) + class HNSWBuilder(LeannBackendBuilderInterface): def __init__(self, **kwargs): self.build_params = kwargs.copy() @@ -46,6 +50,7 @@ class HNSWBuilder(LeannBackendBuilderInterface): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs): from . import faiss + path = Path(index_path) index_dir = path.parent index_prefix = path.stem @@ -56,7 +61,7 @@ class HNSWBuilder(LeannBackendBuilderInterface): label_map = {i: str_id for i, str_id in enumerate(ids)} label_map_file = index_dir / "leann.labels.map" - with open(label_map_file, 'wb') as f: + with open(label_map_file, "wb") as f: pickle.dump(label_map, f) metric_enum = get_metric_map().get(self.distance_metric.lower()) @@ -85,9 +90,7 @@ class HNSWBuilder(LeannBackendBuilderInterface): csr_temp_file = index_file.with_suffix(".csr.tmp") success = convert_hnsw_graph_to_csr( - str(index_file), - str(csr_temp_file), - prune_embeddings=self.is_recompute + str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute ) if success: @@ -95,16 +98,25 @@ class HNSWBuilder(LeannBackendBuilderInterface): index_file_old = index_file.with_suffix(".old") shutil.move(str(index_file), str(index_file_old)) shutil.move(str(csr_temp_file), str(index_file)) - print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'") + print( + f"INFO: Replaced original index with {mode_str} version at '{index_file}'" + ) else: # Clean up and fail fast if csr_temp_file.exists(): os.remove(csr_temp_file) - raise RuntimeError("CSR conversion failed - cannot proceed with compact format") + raise RuntimeError( + "CSR conversion failed - cannot proceed with compact format" + ) + class HNSWSearcher(BaseSearcher): def __init__(self, index_path: str, **kwargs): - super().__init__(index_path, backend_module_name="leann_backend_hnsw.hnsw_embedding_server", **kwargs) + super().__init__( + index_path, + backend_module_name="leann_backend_hnsw.hnsw_embedding_server", + **kwargs, + ) from . import faiss self.distance_metric = self.meta.get("distance_metric", "mips").lower() @@ -113,8 +125,8 @@ class HNSWSearcher(BaseSearcher): raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") self.is_compact, self.is_pruned = ( - self.meta.get('is_compact', True), - self.meta.get('is_pruned', True) + self.meta.get("is_compact", True), + self.meta.get("is_pruned", True), ) index_file = self.index_dir / f"{self.index_path.stem}.index" @@ -130,14 +142,50 @@ class HNSWSearcher(BaseSearcher): self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) - def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: - from . import faiss + def search( + self, + query: np.ndarray, + top_k: int, + complexity: int = 64, + beam_width: int = 1, + prune_ratio: float = 0.0, + recompute_embeddings: bool = False, + pruning_strategy: Literal["global", "local", "proportional"] = "global", + zmq_port: int = 5557, + batch_size: int = 0, + **kwargs, + ) -> Dict[str, Any]: + """ + Search for nearest neighbors using HNSW index. - if self.is_pruned: + Args: + query: Query vectors (B, D) where B is batch size, D is dimension + top_k: Number of nearest neighbors to return + complexity: Search complexity/efSearch, higher = more accurate but slower + beam_width: Number of parallel search paths/beam_size + prune_ratio: Ratio of neighbors to prune via PQ (0.0-1.0) + recompute_embeddings: Whether to fetch fresh embeddings from server + pruning_strategy: PQ candidate selection strategy: + - "global": Use global PQ queue size for selection (default) + - "local": Local pruning, sort and select best candidates + - "proportional": Base selection on new neighbor count ratio + zmq_port: ZMQ port for embedding server + batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific) + **kwargs: Additional HNSW-specific parameters (for legacy compatibility) + + Returns: + Dict with 'labels' (list of lists) and 'distances' (ndarray) + """ + from . import faiss # type: ignore + + # Use recompute_embeddings parameter + use_recompute = recompute_embeddings or self.is_pruned + if use_recompute: meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json" if not meta_file_path.exists(): - raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}") - zmq_port = kwargs.get("zmq_port", 5557) + raise RuntimeError( + f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}" + ) self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs) if query.dtype != np.float32: @@ -146,16 +194,48 @@ class HNSWSearcher(BaseSearcher): faiss.normalize_L2(query) params = faiss.SearchParametersHNSW() - params.zmq_port = kwargs.get("zmq_port", 5557) - params.efSearch = kwargs.get("complexity", 32) - params.beam_size = kwargs.get("beam_width", 1) + params.zmq_port = zmq_port + params.efSearch = complexity + params.beam_size = beam_width - batch_size = query.shape[0] - distances = np.empty((batch_size, top_k), dtype=np.float32) - labels = np.empty((batch_size, top_k), dtype=np.int64) + # PQ pruning: direct mapping to HNSW's pq_pruning_ratio + params.pq_pruning_ratio = prune_ratio - self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params) + # Map pruning_strategy to HNSW parameters + if pruning_strategy == "local": + params.local_prune = True + params.send_neigh_times_ratio = 0.0 + elif pruning_strategy == "proportional": + params.local_prune = False + params.send_neigh_times_ratio = ( + 1.0 # Any value > 1e-6 triggers proportional mode + ) + else: # "global" + params.local_prune = False + params.send_neigh_times_ratio = 0.0 - string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels] + # HNSW-specific batch processing parameter + params.batch_size = batch_size - return {"labels": string_labels, "distances": distances} \ No newline at end of file + batch_size_query = query.shape[0] + distances = np.empty((batch_size_query, top_k), dtype=np.float32) + labels = np.empty((batch_size_query, top_k), dtype=np.int64) + + self._index.search( + query.shape[0], + faiss.swig_ptr(query), + top_k, + faiss.swig_ptr(distances), + faiss.swig_ptr(labels), + params, + ) + + string_labels = [ + [ + self.label_map.get(int_label, f"unknown_{int_label}") + for int_label in batch_labels + ] + for batch_labels in labels + ] + + return {"labels": string_labels, "distances": distances} diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 9b0bf53..f409241 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -175,7 +175,7 @@ class EmbeddingServerManager: self.backend_module_name = backend_module_name self.server_process: Optional[subprocess.Popen] = None self.server_port: Optional[int] = None - atexit.register(self.stop_server) + # atexit.register(self.stop_server) def start_server(self, port: int, model_name: str, **kwargs) -> bool: """ diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py index 2786b62..1c36c1a 100644 --- a/packages/leann-core/src/leann/interface.py +++ b/packages/leann-core/src/leann/interface.py @@ -1,42 +1,55 @@ from abc import ABC, abstractmethod import numpy as np -from typing import Dict, Any +from typing import Dict, Any, Literal class LeannBackendBuilderInterface(ABC): - """用于构建索引的后端接口""" + """Backend interface for building indexes""" @abstractmethod def build(self, data: np.ndarray, index_path: str, **kwargs) -> None: - """构建索引 + """Build index Args: - data: 向量数据 (N, D) - index_path: 索引保存路径 - **kwargs: 后端特定的构建参数 + data: Vector data (N, D) + index_path: Path to save index + **kwargs: Backend-specific build parameters """ pass class LeannBackendSearcherInterface(ABC): - """用于搜索的后端接口""" + """Backend interface for searching""" @abstractmethod def __init__(self, index_path: str, **kwargs): - """初始化搜索器 + """Initialize searcher Args: - index_path: 索引文件路径 - **kwargs: 后端特定的加载参数 + index_path: Path to index file + **kwargs: Backend-specific loading parameters """ pass @abstractmethod - def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: - """搜索最近邻 + def search(self, query: np.ndarray, top_k: int, + complexity: int = 64, + beam_width: int = 1, + prune_ratio: float = 0.0, + recompute_embeddings: bool = False, + pruning_strategy: Literal["global", "local", "proportional"] = "global", + zmq_port: int = 5557, + **kwargs) -> Dict[str, Any]: + """Search for nearest neighbors Args: - query: 查询向量 (1, D) 或 (B, D) - top_k: 返回的最近邻数量 - **kwargs: 搜索参数 + query: Query vectors (B, D) where B is batch size, D is dimension + top_k: Number of nearest neighbors to return + complexity: Search complexity/candidate list size, higher = more accurate but slower + beam_width: Number of parallel search paths/IO requests per iteration + prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0) + recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes + pruning_strategy: PQ candidate selection strategy - "global", "local", or "proportional" + zmq_port: ZMQ port for embedding server communication + **kwargs: Backend-specific parameters Returns: {"labels": [...], "distances": [...]} @@ -44,16 +57,16 @@ class LeannBackendSearcherInterface(ABC): pass class LeannBackendFactoryInterface(ABC): - """后端工厂接口""" + """Backend factory interface""" @staticmethod @abstractmethod def builder(**kwargs) -> LeannBackendBuilderInterface: - """创建 Builder 实例""" + """Create Builder instance""" pass @staticmethod @abstractmethod def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: - """创建 Searcher 实例""" + """Create Searcher instance""" pass \ No newline at end of file diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index b0f5ad3..82069e3 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -1,9 +1,8 @@ - import json import pickle from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, Any, List +from typing import Dict, Any, Literal import numpy as np @@ -40,7 +39,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): self.embedding_model = self.meta.get("embedding_model") if not self.embedding_model: - print("WARNING: embedding_model not found in meta.json. Recompute will fail.") + print( + "WARNING: embedding_model not found in meta.json. Recompute will fail." + ) self.label_map = self._load_label_map() @@ -54,7 +55,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): meta_path = self.index_dir / f"{self.index_path.name}.meta.json" if not meta_path.exists(): raise FileNotFoundError(f"Leann metadata file not found at {meta_path}") - with open(meta_path, 'r', encoding='utf-8') as f: + with open(meta_path, "r", encoding="utf-8") as f: return json.load(f) def _load_label_map(self) -> Dict[int, str]: @@ -62,16 +63,20 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): label_map_file = self.index_dir / "leann.labels.map" if not label_map_file.exists(): raise FileNotFoundError(f"Label map file not found: {label_map_file}") - with open(label_map_file, 'rb') as f: + with open(label_map_file, "rb") as f: return pickle.load(f) - def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None: + def _ensure_server_running( + self, passages_source_file: str, port: int, **kwargs + ) -> None: """ Ensures the embedding server is running if recompute is needed. This is a helper for subclasses. """ if not self.embedding_model: - raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.") + raise ValueError( + "Cannot use recompute mode without 'embedding_model' in meta.json." + ) server_started = self.embedding_server_manager.start_server( port=port, @@ -85,15 +90,38 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): raise RuntimeError(f"Failed to start embedding server on port {port}") @abstractmethod - def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: + def search( + self, + query: np.ndarray, + top_k: int, + complexity: int = 64, + beam_width: int = 1, + prune_ratio: float = 0.0, + recompute_embeddings: bool = False, + pruning_strategy: Literal["global", "local", "proportional"] = "global", + zmq_port: int = 5557, + **kwargs, + ) -> Dict[str, Any]: """ Search for the top_k nearest neighbors of the query vector. - Must be implemented by subclasses. + + Args: + query: Query vectors (B, D) where B is batch size, D is dimension + top_k: Number of nearest neighbors to return + complexity: Search complexity/candidate list size, higher = more accurate but slower + beam_width: Number of parallel search paths/IO requests per iteration + prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0) + recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes + pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional" + zmq_port: ZMQ port for embedding server communication + **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.) + + Returns: + Dict with 'labels' (list of lists) and 'distances' (ndarray) """ pass def __del__(self): """Ensures the embedding server is stopped when the searcher is destroyed.""" - if hasattr(self, 'embedding_server_manager'): + if hasattr(self, "embedding_server_manager"): self.embedding_server_manager.stop_server() -