diff --git a/issue_159.py b/issue_159.py
index e0c43b0..62cb990 100644
--- a/issue_159.py
+++ b/issue_159.py
@@ -13,7 +13,7 @@
 import os
 import time
 from pathlib import Path
-from leann.api import LeannBuilder, LeannSearcher
+from leann.api import LeannBuilder, LeannSearcher, SearchResult
 
 os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
 
@@ -87,7 +87,7 @@ def test_search_performance():
     print("\n Test 1: Default complexity (64)")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=64)
+    results: list[SearchResult] = searcher.search(test_query, top_k=10, complexity=64)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index bbcc8a3..123713c 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -864,7 +864,13 @@ class LeannBuilder:
 
 
 class LeannSearcher:
-    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
+    def __init__(
+        self,
+        index_path: str,
+        enable_warmup: bool = True,
+        recompute_embeddings: bool = True,
+        **backend_kwargs,
+    ):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -895,14 +901,32 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")
+
+        # Global recompute flag for this searcher (explicit knob, default True)
+        self.recompute_embeddings: bool = bool(recompute_embeddings)
+
+        # Warmup flag: keep using the existing enable_warmup parameter,
+        # but default it to True so cold-start happens earlier.
+        self._warmup: bool = bool(enable_warmup)
+
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = enable_warmup
+        final_kwargs["enable_warmup"] = self._warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )
+        # Optional one-shot warmup at construction time to hide cold-start latency.
+        if self._warmup:
+            try:
+                _ = self.backend_impl.compute_query_embedding(
+                    "__LEANN_WARMUP__",
+                    use_server_if_available=self.recompute_embeddings,
+                )
+            except Exception as exc:
+                logger.warning(f"Warmup embedding failed (ignored): {exc}")
+
 
     def search(
         self,
         query: str,
@@ -910,7 +934,7 @@
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = True,
+        recompute_embeddings: Optional[bool] = None,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -927,7 +951,8 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
+            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
+                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
@@ -966,8 +991,19 @@ class LeannSearcher:
 
         zmq_port = None
 
+        # Resolve effective recompute flag for this search.
+        if recompute_embeddings is not None:
+            logger.warning(
+                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
+                "will be removed in a future version. Configure recompute at "
+                "LeannSearcher(..., recompute_embeddings=...) instead."
+            )
+            effective_recompute = bool(recompute_embeddings)
+        else:
+            effective_recompute = self.recompute_embeddings
+
         start_time = time.time()
-        if recompute_embeddings:
+        if effective_recompute:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
@@ -981,7 +1017,7 @@ class LeannSearcher:
 
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            use_server_if_available=recompute_embeddings,
+            use_server_if_available=effective_recompute,
             zmq_port=zmq_port,
         )
         logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -993,7 +1029,7 @@ class LeannSearcher:
             "complexity": complexity,
             "beam_width": beam_width,
             "prune_ratio": prune_ratio,
-            "recompute_embeddings": recompute_embeddings,
+            "recompute_embeddings": effective_recompute,
             "pruning_strategy": pruning_strategy,
             "zmq_port": zmq_port,
         }
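
For reference, a minimal usage sketch of the API change above. The index path "demo.leann" and the query strings are placeholders, not taken from the patch; everything else follows the new signatures in the diff.

# Recompute mode is now configured once on the searcher; the per-call
# recompute_embeddings argument still works but emits a deprecation warning.
from leann.api import LeannSearcher

searcher = LeannSearcher("demo.leann", enable_warmup=True, recompute_embeddings=True)
results = searcher.search("example query", top_k=10, complexity=64)

# Deprecated per-call override: logs a warning, then behaves as before.
results = searcher.search("example query", top_k=10, recompute_embeddings=False)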