feat: different search_args and docstrings

This commit is contained in:
Andy Lee
2025-07-16 15:25:58 -07:00
parent c3fb949693
commit 7b9406a3ea
5 changed files with 231 additions and 65 deletions

View File

@@ -3,7 +3,7 @@ import os
import json import json
import struct import struct
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List from typing import Dict, Any, List, Literal
import contextlib import contextlib
import pickle import pickle
@@ -108,24 +108,69 @@ class DiskannSearcher(BaseSearcher):
kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", "" kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
) )
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(self, query: np.ndarray, top_k: int,
recompute = kwargs.get("recompute_beighbor_embeddings", False) complexity: int = 64,
if recompute: beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs) -> Dict[str, Any]:
"""
Search for nearest neighbors using DiskANN index.
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server
pruning_strategy: PQ candidate selection strategy:
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN; raises NotImplementedError
zmq_port: ZMQ port for embedding server
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError("DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead.")
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json" meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists(): if not meta_file_path.exists():
raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}") raise RuntimeError(f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}")
zmq_port = kwargs.get("zmq_port", self.zmq_port)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs) self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
# Map pruning_strategy to DiskANN's global_pruning parameter
if pruning_strategy == "local":
use_global_pruning = False
else: # "global"
use_global_pruning = True
labels, distances = self._index.batch_search( labels, distances = self._index.batch_search(
query, query.shape[0], top_k, query, query.shape[0], top_k,
kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads, complexity, beam_width, self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False), kwargs.get("USE_DEFERRED_FETCH", False),
recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0), kwargs.get("skip_search_reorder", False),
kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False) use_recompute,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning
) )
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels] string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]

View File

@@ -2,7 +2,7 @@ import numpy as np
import os import os
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List from typing import Dict, Any, List, Literal
import pickle import pickle
import shutil import shutil
@@ -13,17 +13,20 @@ from leann.registry import register_backend
from leann.interface import ( from leann.interface import (
LeannBackendFactoryInterface, LeannBackendFactoryInterface,
LeannBackendBuilderInterface, LeannBackendBuilderInterface,
LeannBackendSearcherInterface LeannBackendSearcherInterface,
) )
def get_metric_map(): def get_metric_map():
from . import faiss from . import faiss
return { return {
"mips": faiss.METRIC_INNER_PRODUCT, "mips": faiss.METRIC_INNER_PRODUCT,
"l2": faiss.METRIC_L2, "l2": faiss.METRIC_L2,
"cosine": faiss.METRIC_INNER_PRODUCT, "cosine": faiss.METRIC_INNER_PRODUCT,
} }
@register_backend("hnsw") @register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface): class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -34,6 +37,7 @@ class HNSWBackend(LeannBackendFactoryInterface):
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
return HNSWSearcher(index_path, **kwargs) return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface): class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs.copy() self.build_params = kwargs.copy()
@@ -46,6 +50,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
from . import faiss from . import faiss
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
@@ -56,7 +61,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
label_map = {i: str_id for i, str_id in enumerate(ids)} label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map" label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f: with open(label_map_file, "wb") as f:
pickle.dump(label_map, f) pickle.dump(label_map, f)
metric_enum = get_metric_map().get(self.distance_metric.lower()) metric_enum = get_metric_map().get(self.distance_metric.lower())
@@ -85,9 +90,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
csr_temp_file = index_file.with_suffix(".csr.tmp") csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr( success = convert_hnsw_graph_to_csr(
str(index_file), str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
str(csr_temp_file),
prune_embeddings=self.is_recompute
) )
if success: if success:
@@ -95,16 +98,25 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file_old = index_file.with_suffix(".old") index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old)) shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file)) shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'") print(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else: else:
# Clean up and fail fast # Clean up and fail fast
if csr_temp_file.exists(): if csr_temp_file.exists():
os.remove(csr_temp_file) os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format") raise RuntimeError(
"CSR conversion failed - cannot proceed with compact format"
)
class HNSWSearcher(BaseSearcher): class HNSWSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
super().__init__(index_path, backend_module_name="leann_backend_hnsw.hnsw_embedding_server", **kwargs) super().__init__(
index_path,
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
**kwargs,
)
from . import faiss from . import faiss
self.distance_metric = self.meta.get("distance_metric", "mips").lower() self.distance_metric = self.meta.get("distance_metric", "mips").lower()
@@ -113,8 +125,8 @@ class HNSWSearcher(BaseSearcher):
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = ( self.is_compact, self.is_pruned = (
self.meta.get('is_compact', True), self.meta.get("is_compact", True),
self.meta.get('is_pruned', True) self.meta.get("is_pruned", True),
) )
index_file = self.index_dir / f"{self.index_path.stem}.index" index_file = self.index_dir / f"{self.index_path.stem}.index"
@@ -130,14 +142,50 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(
from . import faiss self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
"""
Search for nearest neighbors using HNSW index.
if self.is_pruned: Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/efSearch, higher = more accurate but slower
beam_width: Number of parallel search paths/beam_size
prune_ratio: Ratio of neighbors to prune via PQ (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server
pruning_strategy: PQ candidate selection strategy:
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
zmq_port: ZMQ port for embedding server
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
from . import faiss # type: ignore
# Recompute when explicitly requested via recompute_embeddings or when the index is pruned
use_recompute = recompute_embeddings or self.is_pruned
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json" meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists(): if not meta_file_path.exists():
raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}") raise RuntimeError(
zmq_port = kwargs.get("zmq_port", 5557) f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs) self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32: if query.dtype != np.float32:
@@ -146,16 +194,48 @@ class HNSWSearcher(BaseSearcher):
faiss.normalize_L2(query) faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW() params = faiss.SearchParametersHNSW()
params.zmq_port = kwargs.get("zmq_port", 5557) params.zmq_port = zmq_port
params.efSearch = kwargs.get("complexity", 32) params.efSearch = complexity
params.beam_size = kwargs.get("beam_width", 1) params.beam_size = beam_width
batch_size = query.shape[0] # PQ pruning: direct mapping to HNSW's pq_pruning_ratio
distances = np.empty((batch_size, top_k), dtype=np.float32) params.pq_pruning_ratio = prune_ratio
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params) # Map pruning_strategy to HNSW parameters
if pruning_strategy == "local":
params.local_prune = True
params.send_neigh_times_ratio = 0.0
elif pruning_strategy == "proportional":
params.local_prune = False
params.send_neigh_times_ratio = (
1.0 # Any value > 1e-6 triggers proportional mode
)
else: # "global"
params.local_prune = False
params.send_neigh_times_ratio = 0.0
string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels] # HNSW-specific batch processing parameter
params.batch_size = batch_size
return {"labels": string_labels, "distances": distances} batch_size_query = query.shape[0]
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
top_k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
params,
)
string_labels = [
[
self.label_map.get(int_label, f"unknown_{int_label}")
for int_label in batch_labels
]
for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -175,7 +175,7 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None self.server_port: Optional[int] = None
atexit.register(self.stop_server) # atexit.register(self.stop_server)
def start_server(self, port: int, model_name: str, **kwargs) -> bool: def start_server(self, port: int, model_name: str, **kwargs) -> bool:
""" """

View File

@@ -1,42 +1,55 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
from typing import Dict, Any from typing import Dict, Any, Literal
class LeannBackendBuilderInterface(ABC): class LeannBackendBuilderInterface(ABC):
"""用于构建索引的后端接口""" """Backend interface for building indexes"""
@abstractmethod @abstractmethod
def build(self, data: np.ndarray, index_path: str, **kwargs) -> None: def build(self, data: np.ndarray, index_path: str, **kwargs) -> None:
"""构建索引 """Build index
Args: Args:
data: 向量数据 (N, D) data: Vector data (N, D)
index_path: 索引保存路径 index_path: Path to save index
**kwargs: 后端特定的构建参数 **kwargs: Backend-specific build parameters
""" """
pass pass
class LeannBackendSearcherInterface(ABC): class LeannBackendSearcherInterface(ABC):
"""用于搜索的后端接口""" """Backend interface for searching"""
@abstractmethod @abstractmethod
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
"""初始化搜索器 """Initialize searcher
Args: Args:
index_path: 索引文件路径 index_path: Path to index file
**kwargs: 后端特定的加载参数 **kwargs: Backend-specific loading parameters
""" """
pass pass
@abstractmethod @abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(self, query: np.ndarray, top_k: int,
"""搜索最近邻 complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs) -> Dict[str, Any]:
"""Search for nearest neighbors
Args: Args:
query: 查询向量 (1, D) 或 (B, D) query: Query vectors (B, D) where B is batch size, D is dimension
top_k: 返回的最近邻数量 top_k: Number of nearest neighbors to return
**kwargs: 搜索参数 complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global", "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters
Returns: Returns:
{"labels": [...], "distances": [...]} {"labels": [...], "distances": [...]}
@@ -44,16 +57,16 @@ class LeannBackendSearcherInterface(ABC):
pass pass
class LeannBackendFactoryInterface(ABC): class LeannBackendFactoryInterface(ABC):
"""后端工厂接口""" """Backend factory interface"""
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def builder(**kwargs) -> LeannBackendBuilderInterface: def builder(**kwargs) -> LeannBackendBuilderInterface:
"""创建 Builder 实例""" """Create Builder instance"""
pass pass
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface: def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
"""创建 Searcher 实例""" """Create Searcher instance"""
pass pass

View File

@@ -1,9 +1,8 @@
import json import json
import pickle import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List from typing import Dict, Any, Literal
import numpy as np import numpy as np
@@ -40,7 +39,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
self.embedding_model = self.meta.get("embedding_model") self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model: if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail.") print(
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.label_map = self._load_label_map() self.label_map = self._load_label_map()
@@ -54,7 +55,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
meta_path = self.index_dir / f"{self.index_path.name}.meta.json" meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists(): if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}") raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, 'r', encoding='utf-8') as f: with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
def _load_label_map(self) -> Dict[int, str]: def _load_label_map(self) -> Dict[int, str]:
@@ -62,16 +63,20 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
label_map_file = self.index_dir / "leann.labels.map" label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists(): if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}") raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f: with open(label_map_file, "rb") as f:
return pickle.load(f) return pickle.load(f)
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None: def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> None:
""" """
Ensures the embedding server is running if recompute is needed. Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses. This is a helper for subclasses.
""" """
if not self.embedding_model: if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.") raise ValueError(
"Cannot use recompute mode without 'embedding_model' in meta.json."
)
server_started = self.embedding_server_manager.start_server( server_started = self.embedding_server_manager.start_server(
port=port, port=port,
@@ -85,15 +90,38 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
raise RuntimeError(f"Failed to start embedding server on port {port}") raise RuntimeError(f"Failed to start embedding server on port {port}")
@abstractmethod @abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
""" """
Search for the top_k nearest neighbors of the query vector. Search for the top_k nearest neighbors of the query vector.
Must be implemented by subclasses.
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
""" """
pass pass
def __del__(self): def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed.""" """Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, 'embedding_server_manager'): if hasattr(self, "embedding_server_manager"):
self.embedding_server_manager.stop_server() self.embedding_server_manager.stop_server()