fix: recompute args in searcher
@@ -13,7 +13,7 @@ import os
 import time
 from pathlib import Path
 
-from leann.api import LeannBuilder, LeannSearcher
+from leann.api import LeannBuilder, LeannSearcher, SearchResult
 
 os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
 
@@ -87,7 +87,7 @@ def test_search_performance():
     print("\n Test 1: Default complexity (64)")
     print(f" Query: '{test_query}'")
     start_time = time.time()
-    results = searcher.search(test_query, top_k=10, complexity=64)
+    results: list[SearchResult] = searcher.search(test_query, top_k=10, complexity=64)
     search_time = time.time() - start_time
     print(f" ✓ Search completed in {search_time:.2f} seconds")
     print(f" Results: {len(results)} items")
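
For reference, the annotated call in the updated test can be exercised on its own; a minimal sketch, assuming an index has already been built at a hypothetical path "demo.leann":

import time

from leann.api import LeannSearcher, SearchResult

searcher = LeannSearcher("demo.leann")  # hypothetical index path, not from this diff

start = time.time()
results: list[SearchResult] = searcher.search("hello world", top_k=10, complexity=64)
print(f"search took {time.time() - start:.2f}s, returned {len(results)} items")
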
@@ -864,7 +864,13 @@ class LeannBuilder:
 
 
 class LeannSearcher:
-    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
+    def __init__(
+        self,
+        index_path: str,
+        enable_warmup: bool = True,
+        recompute_embeddings: bool = True,
+        **backend_kwargs,
+    ):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -895,14 +901,32 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")
 
+        # Global recompute flag for this searcher (explicit knob, default True)
+        self.recompute_embeddings: bool = bool(recompute_embeddings)
+
+        # Warmup flag: keep using the existing enable_warmup parameter,
+        # but default it to True so cold-start happens earlier.
+        self._warmup: bool = bool(enable_warmup)
+
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = enable_warmup
+        final_kwargs["enable_warmup"] = self._warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )
 
+        # Optional one-shot warmup at construction time to hide cold-start latency.
+        if self._warmup:
+            try:
+                _ = self.backend_impl.compute_query_embedding(
+                    "__LEANN_WARMUP__",
+                    use_server_if_available=self.recompute_embeddings,
+                )
+            except Exception as exc:
+                logger.warning(f"Warmup embedding failed (ignored): {exc}")
 
     def search(
         self,
         query: str,
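
With these changes, recompute mode and warmup are configured per searcher instance rather than per call. A minimal usage sketch (the index path is hypothetical, not from this diff):

from leann.api import LeannSearcher

# Defaults: warmup at construction, recompute embeddings via the embedding server.
searcher = LeannSearcher("demo.leann")

# Opt out of both, e.g. to search against stored codes without a server round-trip.
offline_searcher = LeannSearcher(
    "demo.leann",
    enable_warmup=False,
    recompute_embeddings=False,
)
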
@@ -910,7 +934,7 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = True,
+        recompute_embeddings: Optional[bool] = None,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -927,7 +951,8 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
+            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
+                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
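
To make the deprecation path concrete: passing recompute_embeddings to search() still works as a one-off override but now logs a warning, and the constructor is the intended place to set it. A short sketch under the same hypothetical index path as above:

from leann.api import LeannSearcher

# Preferred: set the recompute mode once, at construction time.
searcher = LeannSearcher("demo.leann", recompute_embeddings=True)
results = searcher.search("hello world", top_k=5)

# Still accepted, but deprecated: per-call override (emits a logger warning).
results = searcher.search("hello world", top_k=5, recompute_embeddings=False)
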
@@ -966,8 +991,19 @@ class LeannSearcher:
 
         zmq_port = None
 
+        # Resolve effective recompute flag for this search.
+        if recompute_embeddings is not None:
+            logger.warning(
+                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
+                "will be removed in a future version. Configure recompute at "
+                "LeannSearcher(..., recompute_embeddings=...) instead."
+            )
+            effective_recompute = bool(recompute_embeddings)
+        else:
+            effective_recompute = self.recompute_embeddings
+
         start_time = time.time()
-        if recompute_embeddings:
+        if effective_recompute:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
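
The resolution block above reduces to a simple precedence rule: an explicit per-call value wins, otherwise the instance default applies. A standalone sketch of that rule (function name is illustrative, not part of the API):

from typing import Optional

def resolve_recompute(per_call: Optional[bool], instance_default: bool) -> bool:
    """Per-call override, when given, beats the searcher-level default."""
    return instance_default if per_call is None else bool(per_call)

assert resolve_recompute(None, True) is True    # falls back to the instance flag
assert resolve_recompute(False, True) is False  # explicit override wins
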
@@ -981,7 +1017,7 @@ class LeannSearcher:
 
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            use_server_if_available=recompute_embeddings,
+            use_server_if_available=effective_recompute,
             zmq_port=zmq_port,
         )
         logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -993,7 +1029,7 @@ class LeannSearcher:
             "complexity": complexity,
             "beam_width": beam_width,
             "prune_ratio": prune_ratio,
-            "recompute_embeddings": recompute_embeddings,
+            "recompute_embeddings": effective_recompute,
             "pruning_strategy": pruning_strategy,
             "zmq_port": zmq_port,
         }