fix: recompute args in searcher
This commit is contained in:
@@ -13,7 +13,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannSearcher, SearchResult
|
||||||
|
|
||||||
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
|
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ def test_search_performance():
|
|||||||
print("\n Test 1: Default complexity (64) `1 ")
|
print("\n Test 1: Default complexity (64) `1 ")
|
||||||
print(f" Query: '{test_query}'")
|
print(f" Query: '{test_query}'")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = searcher.search(test_query, top_k=10, complexity=64)
|
results: list[SearchResult] = searcher.search(test_query, top_k=10, complexity=64)
|
||||||
search_time = time.time() - start_time
|
search_time = time.time() - start_time
|
||||||
print(f" ✓ Search completed in {search_time:.2f} seconds")
|
print(f" ✓ Search completed in {search_time:.2f} seconds")
|
||||||
print(f" Results: {len(results)} items")
|
print(f" Results: {len(results)} items")
|
||||||
|
|||||||
@@ -864,7 +864,13 @@ class LeannBuilder:
|
|||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
index_path: str,
|
||||||
|
enable_warmup: bool = True,
|
||||||
|
recompute_embeddings: bool = True,
|
||||||
|
**backend_kwargs,
|
||||||
|
):
|
||||||
# Fix path resolution for Colab and other environments
|
# Fix path resolution for Colab and other environments
|
||||||
if not Path(index_path).is_absolute():
|
if not Path(index_path).is_absolute():
|
||||||
index_path = str(Path(index_path).resolve())
|
index_path = str(Path(index_path).resolve())
|
||||||
@@ -895,14 +901,32 @@ class LeannSearcher:
|
|||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
|
|
||||||
|
# Global recompute flag for this searcher (explicit knob, default True)
|
||||||
|
self.recompute_embeddings: bool = bool(recompute_embeddings)
|
||||||
|
|
||||||
|
# Warmup flag: keep using the existing enable_warmup parameter,
|
||||||
|
# but default it to True so cold-start happens earlier.
|
||||||
|
self._warmup: bool = bool(enable_warmup)
|
||||||
|
|
||||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||||
final_kwargs["enable_warmup"] = enable_warmup
|
final_kwargs["enable_warmup"] = self._warmup
|
||||||
if self.embedding_options:
|
if self.embedding_options:
|
||||||
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||||
index_path, **final_kwargs
|
index_path, **final_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Optional one-shot warmup at construction time to hide cold-start latency.
|
||||||
|
if self._warmup:
|
||||||
|
try:
|
||||||
|
_ = self.backend_impl.compute_query_embedding(
|
||||||
|
"__LEANN_WARMUP__",
|
||||||
|
use_server_if_available=self.recompute_embeddings,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Warmup embedding failed (ignored): {exc}")
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -910,7 +934,7 @@ class LeannSearcher:
|
|||||||
complexity: int = 64,
|
complexity: int = 64,
|
||||||
beam_width: int = 1,
|
beam_width: int = 1,
|
||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: Optional[bool] = None,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
@@ -927,7 +951,8 @@ class LeannSearcher:
|
|||||||
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
beam_width: Number of parallel search paths/IO requests per iteration
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
recompute_embeddings: (Deprecated) Per-call override for recompute mode.
|
||||||
|
Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
|
||||||
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
expected_zmq_port: ZMQ port for embedding server communication
|
expected_zmq_port: ZMQ port for embedding server communication
|
||||||
metadata_filters: Optional filters to apply to search results based on metadata.
|
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||||
@@ -966,8 +991,19 @@ class LeannSearcher:
|
|||||||
|
|
||||||
zmq_port = None
|
zmq_port = None
|
||||||
|
|
||||||
|
# Resolve effective recompute flag for this search.
|
||||||
|
if recompute_embeddings is not None:
|
||||||
|
logger.warning(
|
||||||
|
"LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
|
||||||
|
"will be removed in a future version. Configure recompute at "
|
||||||
|
"LeannSearcher(..., recompute_embeddings=...) instead."
|
||||||
|
)
|
||||||
|
effective_recompute = bool(recompute_embeddings)
|
||||||
|
else:
|
||||||
|
effective_recompute = self.recompute_embeddings
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if recompute_embeddings:
|
if effective_recompute:
|
||||||
zmq_port = self.backend_impl._ensure_server_running(
|
zmq_port = self.backend_impl._ensure_server_running(
|
||||||
self.meta_path_str,
|
self.meta_path_str,
|
||||||
port=expected_zmq_port,
|
port=expected_zmq_port,
|
||||||
@@ -981,7 +1017,7 @@ class LeannSearcher:
|
|||||||
|
|
||||||
query_embedding = self.backend_impl.compute_query_embedding(
|
query_embedding = self.backend_impl.compute_query_embedding(
|
||||||
query,
|
query,
|
||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=effective_recompute,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
@@ -993,7 +1029,7 @@ class LeannSearcher:
|
|||||||
"complexity": complexity,
|
"complexity": complexity,
|
||||||
"beam_width": beam_width,
|
"beam_width": beam_width,
|
||||||
"prune_ratio": prune_ratio,
|
"prune_ratio": prune_ratio,
|
||||||
"recompute_embeddings": recompute_embeddings,
|
"recompute_embeddings": effective_recompute,
|
||||||
"pruning_strategy": pruning_strategy,
|
"pruning_strategy": pruning_strategy,
|
||||||
"zmq_port": zmq_port,
|
"zmq_port": zmq_port,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user