feat(hnsw): add batch_size to LeannSearcher.search and LeannChat.ask; forward only for HNSW backend

Author: yichuan-w
Date:   2025-08-18 17:04:40 -07:00
parent d14faaa771
commit c5d8138349

@@ -557,6 +557,8 @@ class LeannSearcher:
         self.passage_manager = PassageManager(
             self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
         )
+        # Preserve backend name for conditional parameter forwarding
+        self.backend_name = backend_name
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")
@@ -576,6 +578,7 @@ class LeannSearcher:
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
+        batch_size: int = 0,
         **kwargs,
     ) -> list[SearchResult]:
         logger.info("🔍 LeannSearcher.search() called:")
@@ -618,16 +621,24 @@
         logger.info(f" Embedding time: {embedding_time} seconds")
         start_time = time.time()
+        backend_search_kwargs: dict[str, Any] = {
+            "complexity": complexity,
+            "beam_width": beam_width,
+            "prune_ratio": prune_ratio,
+            "recompute_embeddings": recompute_embeddings,
+            "pruning_strategy": pruning_strategy,
+            "zmq_port": zmq_port,
+        }
+        # Only HNSW supports batching; forward conditionally
+        if self.backend_name == "hnsw":
+            backend_search_kwargs["batch_size"] = batch_size
+        # Merge any extra kwargs last
+        backend_search_kwargs.update(kwargs)
         results = self.backend_impl.search(
             query_embedding,
             top_k,
-            complexity=complexity,
-            beam_width=beam_width,
-            prune_ratio=prune_ratio,
-            recompute_embeddings=recompute_embeddings,
-            pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
-            **kwargs,
+            **backend_search_kwargs,
         )
         search_time = time.time() - start_time
         logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
@@ -731,6 +742,7 @@ class LeannChat:
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         llm_kwargs: Optional[dict[str, Any]] = None,
         expected_zmq_port: int = 5557,
+        batch_size: int = 0,
         **search_kwargs,
     ):
         if llm_kwargs is None:
@@ -745,6 +757,7 @@ class LeannChat:
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
             expected_zmq_port=expected_zmq_port,
+            batch_size=batch_size,
             **search_kwargs,
         )
         search_time = time.time() - search_time
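
The same parameter rides along through LeannChat.ask. A hedged sketch, assuming a constructor that takes an index path and an llm_config dict (those arguments are not shown in this commit); only the batch_size keyword in ask() comes from the diff.

from leann import LeannChat  # import path assumed

# Constructor arguments are illustrative assumptions, not taken from this commit.
chat = LeannChat("./demo.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})

answer = chat.ask(
    "Summarize the indexed documents.",
    top_k=5,
    batch_size=64,  # passed through to LeannSearcher.search(); only affects HNSW
)
print(answer)
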