diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py
index baf4fcc..d52d7ce 100644
--- a/examples/main_cli_example.py
+++ b/examples/main_cli_example.py
@@ -70,9 +70,7 @@ async def main(args):
         # )
 
         print(f"You: {query}")
-        chat_response = chat.ask(
-            query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
-        )
+        chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
         print(f"Leann: {chat_response}")
 
 
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
index bbd042d..a28a744 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
@@ -4,7 +4,6 @@ import struct
 from pathlib import Path
 from typing import Dict, Any, List, Literal
 import contextlib
-import pickle
 
 from leann.searcher_base import BaseSearcher
 from leann.registry import register_backend
@@ -70,7 +69,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)
-
         build_kwargs = {**self.build_params, **kwargs}
         metric_enum = _get_diskann_metrics().get(
             build_kwargs.get("distance_metric", "mips").lower()
         )
@@ -207,8 +205,7 @@ class DiskannSearcher(BaseSearcher):
         )
 
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
 
         return {"labels": string_labels, "distances": distances}
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
index b7061e1..f1f8da0 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
@@ -1,10 +1,9 @@
 import numpy as np
 import os
 from pathlib import Path
-from typing import Dict, Any, List, Literal
-import pickle
+from typing import Dict, Any, List, Literal, Optional
 import shutil
-import time
+import logging
 
 from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -16,6 +15,8 @@ from leann.interface import (
     LeannBackendSearcherInterface,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def get_metric_map():
     from . import faiss  # type: ignore
@@ -57,9 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         index_dir.mkdir(parents=True, exist_ok=True)
 
         if data.dtype != np.float32:
+            logger.warning(f"Converting data to float32, shape: {data.shape}")
             data = data.astype(np.float32)
 
-
         metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -81,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
     def _convert_to_csr(self, index_file: Path):
         """Convert built index to CSR format"""
         mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
-        print(f"INFO: Converting HNSW index to {mode_str} format...")
+        logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
 
         csr_temp_file = index_file.with_suffix(".csr.tmp")
 
@@ -90,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         )
 
         if success:
-            print("✅ CSR conversion successful.")
+            logger.info("✅ CSR conversion successful.")
             index_file_old = index_file.with_suffix(".old")
             shutil.move(str(index_file), str(index_file_old))
             shutil.move(str(csr_temp_file), str(index_file))
-            print(
+            logger.info(
                 f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
             )
         else:
@@ -131,13 +132,11 @@ class HNSWSearcher(BaseSearcher):
 
         hnsw_config = faiss.HNSWIndexConfig()
         hnsw_config.is_compact = self.is_compact
-        hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
-
-        if self.is_pruned and not hnsw_config.is_recompute:
-            raise RuntimeError("Index is pruned but recompute is disabled.")
+        hnsw_config.is_recompute = (
+            self.is_pruned
+        )  # In C++ code, it's called is_recompute, but it's only for loading IIUC.
 
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
-
 
     def search(
         self,
@@ -146,9 +145,9 @@
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         batch_size: int = 0,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -166,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
                 - "global": Use global PQ queue size for selection (default)
                 - "local": Local pruning, sort and select best candidates
                 - "proportional": Base selection on new neighbor count ratio
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
             **kwargs: Additional HNSW-specific parameters (for legacy compatibility)
 
@@ -175,15 +174,9 @@
         """
         from . import faiss  # type: ignore
 
-        # Use recompute_embeddings parameter
-        use_recompute = recompute_embeddings or self.is_pruned
-        if use_recompute:
-            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
-            if not meta_file_path.exists():
-                raise RuntimeError(
-                    f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
-                )
-            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
+        if not recompute_embeddings:
+            if self.is_pruned:
+                raise RuntimeError("Recompute is required for pruned index.")
 
         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -191,7 +184,7 @@
             faiss.normalize_L2(query)
 
         params = faiss.SearchParametersHNSW()
-        params.zmq_port = zmq_port
+        params.zmq_port = expected_zmq_port
         params.efSearch = complexity
         params.beam_size = beam_width
 
@@ -228,8 +221,7 @@
         )
 
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
 
         return {"labels": string_labels, "distances": distances}
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index 48f8e1b..3d3c9fa 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
-    passages_data: Optional[Dict[str, str]] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
     """
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
""" - # Auto-detect mode based on model name if not explicitly set - if embedding_mode == "sentence-transformers" and model_name.startswith( - "text-embedding-" - ): - embedding_mode = "openai" - print(f"Starting HNSW server on port {zmq_port} with model {model_name}") print(f"Using embedding mode: {embedding_mode}") @@ -64,6 +57,7 @@ def create_hnsw_embedding_server( finally: sys.path.pop(0) + # Check port availability import socket @@ -78,13 +72,15 @@ def create_hnsw_embedding_server( # Only support metadata file, fail fast for everything else if not passages_file or not passages_file.endswith(".meta.json"): raise ValueError("Only metadata files (.meta.json) are supported") - + # Load metadata to get passage sources with open(passages_file, "r") as f: meta = json.load(f) - + passages = PassageManager(meta["passage_sources"]) - print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata") + print( + f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" + ) def zmq_server_thread(): """ZMQ server thread""" @@ -112,7 +108,7 @@ def create_hnsw_embedding_server( f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode" ) - # Use unified embedding computation + # Use unified embedding computation (now with model caching) embeddings = compute_embeddings( request_payload, model_name, mode=embedding_mode ) @@ -148,15 +144,15 @@ def create_hnsw_embedding_server( texts.append(txt) except KeyError: print(f"ERROR: Passage ID {nid} not found") - raise RuntimeError(f"FATAL: Passage with ID {nid} not found") + raise RuntimeError( + f"FATAL: Passage with ID {nid} not found" + ) except Exception as e: print(f"ERROR: Exception looking up passage ID {nid}: {e}") raise # Process embeddings - embeddings = compute_embeddings( - texts, model_name, mode=embedding_mode - ) + embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) print( f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" ) @@ -204,7 +200,9 @@ def create_hnsw_embedding_server( passage_data = passages.get_passage(str(nid)) txt = passage_data["text"] if not txt: - raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") + raise RuntimeError( + f"FATAL: Empty text for passage ID {nid}" + ) texts.append(txt) except KeyError: raise RuntimeError(f"FATAL: Passage with ID {nid} not found") diff --git a/packages/leann-backend-hnsw/third_party/msgpack-c b/packages/leann-backend-hnsw/third_party/msgpack-c index 9b801f0..a0b2ec0 160000 --- a/packages/leann-backend-hnsw/third_party/msgpack-c +++ b/packages/leann-backend-hnsw/third_party/msgpack-c @@ -1 +1 @@ -Subproject commit 9b801f087ab7434f2ab1ab3c0f48a966c19d3b70 +Subproject commit a0b2ec09da4bd823e40fa591221713951d4ec995 diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 75a6d2b..ff41912 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code. 
 import json
 import pickle
+from leann.interface import LeannBackendSearcherInterface
 import numpy as np
+import time
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
@@ -126,6 +128,7 @@ class PassageManager:
     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
+            # Lazy file opening - only open when needed
             with open(passage_file, "r", encoding="utf-8") as f:
                 f.seek(offset)
                 return json.loads(f.readline())
@@ -373,10 +376,12 @@ class LeannBuilder:
 
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
-        meta_path_str = f"{index_path}.meta.json"
-        if not Path(meta_path_str).exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
-        with open(meta_path_str, "r", encoding="utf-8") as f:
+        self.meta_path_str = f"{index_path}.meta.json"
+        if not Path(self.meta_path_str).exists():
+            raise FileNotFoundError(
+                f"Leann metadata file not found at {self.meta_path_str}"
+            )
+        with open(self.meta_path_str, "r", encoding="utf-8") as f:
             self.meta_data = json.load(f)
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
@@ -390,7 +395,9 @@ class LeannSearcher:
             raise ValueError(f"Backend '{backend_name}' not found.")
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
         final_kwargs["enable_warmup"] = enable_warmup
-        self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
+        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
+            index_path, **final_kwargs
+        )
 
     def search(
         self,
@@ -399,9 +406,9 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        expected_zmq_port: int = 5557,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -409,16 +416,21 @@
         print(f"  Top_k: {top_k}")
         print(f"  Additional kwargs: {kwargs}")
 
-        # Use backend's compute_query_embedding method
-        # This will automatically use embedding server if available and needed
-        import time
-
         start_time = time.time()
+        zmq_port = None
+        if recompute_embeddings:
+            zmq_port = self.backend_impl._ensure_server_running(
+                self.meta_path_str,
+                port=expected_zmq_port,
+                **kwargs,
+            )
+        del expected_zmq_port
+
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            expected_zmq_port,
             use_server_if_available=recompute_embeddings,
+            zmq_port=zmq_port,
         )
         print(f"  Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
@@ -433,7 +445,7 @@
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=expected_zmq_port,
+            expected_zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -488,10 +500,10 @@
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
+        expected_zmq_port: int = 5557,
         **search_kwargs,
     ):
         if llm_kwargs is None:
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index 25e505c..3b30798 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
 
 import numpy as np
 import torch
-from typing import List
+from typing import List, Dict, Any, Optional
 import logging
 
 logger = logging.getLogger(__name__)
 
+# Global model cache to avoid repeated loading
+_model_cache: Dict[str, Any] = {}
+
 
 def compute_embeddings(
     texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False
@@ -45,25 +48,12 @@ def compute_embeddings_sentence_transformers(
     is_build: bool = False,
 ) -> np.ndarray:
     """
-    Compute embeddings using SentenceTransformer
-    Preserves all optimization parameters to ensure consistency with original embedding_server
-
-    Args:
-        texts: List of texts to compute embeddings for
-        model_name: SentenceTransformer model name
-        use_fp16: Whether to use FP16 precision
-        device: Device selection ('auto', 'cuda', 'mps', 'cpu')
-        batch_size: Batch size for processing
-
-    Returns:
-        Normalized embeddings array, shape: (len(texts), embedding_dim)
+    Compute embeddings using SentenceTransformer with model caching
     """
     print(
         f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
     )
-    from sentence_transformers import SentenceTransformer
-
 
     # Auto-detect device
     if device == "auto":
         if torch.cuda.is_available():
@@ -73,62 +63,72 @@
         else:
             device = "cpu"
 
-    print(f"INFO: Using device: {device}")
+    # Create cache key
+    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
+
+    # Check if model is already cached
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached model: {model_name}")
+        model = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching SentenceTransformer model: {model_name}")
+        from sentence_transformers import SentenceTransformer
 
-    # Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
-    model_kwargs = {
-        "torch_dtype": torch.float16 if use_fp16 else torch.float32,
-        "low_cpu_mem_usage": True,
-        "_fast_init": True,  # Skip weight initialization checks for faster loading
-    }
+        print(f"INFO: Using device: {device}")
 
-    tokenizer_kwargs = {
-        "use_fast": True,  # Use fast tokenizer for better runtime performance
-    }
+        # Prepare model and tokenizer optimization parameters
+        model_kwargs = {
+            "torch_dtype": torch.float16 if use_fp16 else torch.float32,
+            "low_cpu_mem_usage": True,
+            "_fast_init": True,
+        }
 
-    # Load SentenceTransformer (try local first, then network)
-    print(f"INFO: Loading SentenceTransformer model: {model_name}")
+        tokenizer_kwargs = {
+            "use_fast": True,
+        }
 
-    try:
-        # Try local loading (avoid network delays)
-        model_kwargs["local_files_only"] = True
-        tokenizer_kwargs["local_files_only"] = True
-
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=True,
-        )
-        print("✅ Model loaded successfully! (local + optimized)")
(local + optimized)") - except Exception as e: - print(f"Local loading failed ({e}), trying network download...") - # Fallback to network loading - model_kwargs["local_files_only"] = False - tokenizer_kwargs["local_files_only"] = False - - model = SentenceTransformer( - model_name, - device=device, - model_kwargs=model_kwargs, - tokenizer_kwargs=tokenizer_kwargs, - local_files_only=False, - ) - print("✅ Model loaded successfully! (network + optimized)") - - # Apply additional optimizations (if supported) - if use_fp16 and device in ["cuda", "mps"]: try: - model = model.half() - model = torch.compile(model) - print(f"✅ Using FP16 precision and compile optimization: {model_name}") - except Exception as e: - print( - f"FP16 or compile optimization failed, continuing with default settings: {e}" - ) + # Try local loading first + model_kwargs["local_files_only"] = True + tokenizer_kwargs["local_files_only"] = True - # Compute embeddings (using SentenceTransformer's optimized implementation) + model = SentenceTransformer( + model_name, + device=device, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + local_files_only=True, + ) + print("✅ Model loaded successfully! (local + optimized)") + except Exception as e: + print(f"Local loading failed ({e}), trying network download...") + # Fallback to network loading + model_kwargs["local_files_only"] = False + tokenizer_kwargs["local_files_only"] = False + + model = SentenceTransformer( + model_name, + device=device, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + local_files_only=False, + ) + print("✅ Model loaded successfully! (network + optimized)") + + # Apply additional optimizations (if supported) + if use_fp16 and device in ["cuda", "mps"]: + try: + model = model.half() + model = torch.compile(model) + print(f"✅ Using FP16 precision and compile optimization: {model_name}") + except Exception as e: + print(f"FP16 or compile optimization failed: {e}") + + # Cache the model + _model_cache[cache_key] = model + print(f"✅ Model cached: {cache_key}") + + # Compute embeddings print("INFO: Starting embedding computation...") embeddings = model.encode( @@ -136,7 +136,7 @@ def compute_embeddings_sentence_transformers( batch_size=batch_size, show_progress_bar=is_build, # Don't show progress bar in server environment convert_to_numpy=True, - normalize_embeddings=False, # Keep consistent with original API behavior + normalize_embeddings=False, device=device, ) @@ -166,7 +166,14 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray: if not api_key: raise RuntimeError("OPENAI_API_KEY environment variable not set") - client = openai.OpenAI(api_key=api_key) + # Cache OpenAI client + cache_key = "openai_client" + if cache_key in _model_cache: + client = _model_cache[cache_key] + else: + client = openai.OpenAI(api_key=api_key) + _model_cache[cache_key] = client + print("✅ OpenAI client cached") print( f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'" @@ -214,7 +221,6 @@ def compute_embeddings_mlx( try: import mlx.core as mx from mlx_lm.utils import load - from tqdm import tqdm except ImportError as e: raise RuntimeError( "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm" @@ -224,8 +230,16 @@ def compute_embeddings_mlx( f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..." 
     )
 
-    # Load model and tokenizer
-    model, tokenizer = load(model_name)
+    # Cache MLX model and tokenizer
+    cache_key = f"mlx_{model_name}"
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached MLX model: {model_name}")
+        model, tokenizer = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching MLX model: {model_name}")
+        model, tokenizer = load(model_name)
+        _model_cache[cache_key] = (model, tokenizer)
+        print(f"✅ MLX model cached: {cache_key}")
 
     # Process chunks in batches with progress bar
     all_embeddings = []
diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py
index dfa4cc1..338c3dc 100644
--- a/packages/leann-core/src/leann/interface.py
+++ b/packages/leann-core/src/leann/interface.py
@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
         """
         pass
 
+    @abstractmethod
+    def _ensure_server_running(
+        self, passages_source_file: str, port: Optional[int], **kwargs
+    ) -> int:
+        """Ensure server is running"""
+        pass
+
     @abstractmethod
     def search(
         self,
@@ -44,7 +51,7 @@
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +64,7 @@
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            expected_zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters
 
         Returns:
@@ -69,14 +76,14 @@
     def compute_query_embedding(
         self,
         query: str,
-        expected_zmq_port: Optional[int] = None,
         use_server_if_available: bool = True,
+        zmq_port: Optional[int] = None,
     ) -> np.ndarray:
         """Compute embedding for a query string
 
         Args:
             query: The query string to embed
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first
 
         Returns:
diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py
index fad5589..73f979a 100644
--- a/packages/leann-core/src/leann/searcher_base.py
+++ b/packages/leann-core/src/leann/searcher_base.py
@@ -1,5 +1,4 @@
 import json
-import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, Any, Literal, Optional
@@ -88,15 +87,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
     def compute_query_embedding(
         self,
         query: str,
-        expected_zmq_port: int = 5557,
         use_server_if_available: bool = True,
+        zmq_port: int = 5557,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.
 
         Args:
             query: The query string to embed
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first
 
         Returns:
@@ -110,7 +109,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 self.index_dir / f"{self.index_path.name}.meta.json"
             )
             zmq_port = self._ensure_server_running(
-                str(passages_source_file), expected_zmq_port
+                str(passages_source_file), zmq_port
             )
 
             return self._compute_embedding_via_server([query], zmq_port)[
@@ -168,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -182,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            expected_zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
 
         Returns: