From c5d8138349e954fcf1f3a796eec7acc5502dd324 Mon Sep 17 00:00:00 2001
From: yichuan-w
Date: Mon, 18 Aug 2025 17:04:40 -0700
Subject: [PATCH] feat(hnsw): add batch_size to LeannSearcher.search and
 LeannChat.ask; forward only for HNSW backend

---
 packages/leann-core/src/leann/api.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 71e6b69..41cda80 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -557,6 +557,8 @@ class LeannSearcher:
         self.passage_manager = PassageManager(
             self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
         )
+        # Preserve backend name for conditional parameter forwarding
+        self.backend_name = backend_name
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")
@@ -576,6 +578,7 @@
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
+        batch_size: int = 0,
         **kwargs,
     ) -> list[SearchResult]:
         logger.info("🔍 LeannSearcher.search() called:")
@@ -618,16 +621,24 @@
             logger.info(f" Embedding time: {embedding_time} seconds")
 
         start_time = time.time()
+        backend_search_kwargs: dict[str, Any] = {
+            "complexity": complexity,
+            "beam_width": beam_width,
+            "prune_ratio": prune_ratio,
+            "recompute_embeddings": recompute_embeddings,
+            "pruning_strategy": pruning_strategy,
+            "zmq_port": zmq_port,
+        }
+        # Only HNSW supports batching; forward conditionally
+        if self.backend_name == "hnsw":
+            backend_search_kwargs["batch_size"] = batch_size
+        # Merge any extra kwargs last
+        backend_search_kwargs.update(kwargs)
+
         results = self.backend_impl.search(
             query_embedding,
             top_k,
-            complexity=complexity,
-            beam_width=beam_width,
-            prune_ratio=prune_ratio,
-            recompute_embeddings=recompute_embeddings,
-            pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
-            **kwargs,
+            **backend_search_kwargs,
         )
         search_time = time.time() - start_time
         logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
@@ -731,6 +742,7 @@ class LeannChat:
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         llm_kwargs: Optional[dict[str, Any]] = None,
         expected_zmq_port: int = 5557,
+        batch_size: int = 0,
         **search_kwargs,
     ):
         if llm_kwargs is None:
@@ -745,6 +757,7 @@
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
             expected_zmq_port=expected_zmq_port,
+            batch_size=batch_size,
             **search_kwargs,
         )
         search_time = time.time() - search_time
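
Usage note (illustrative only, not part of the patch): a minimal sketch of how the new
batch_size parameter could be exercised after this change is applied. The import path,
constructor arguments, index path, query text, and top_k value below are assumptions for
the sake of the example; only the batch_size keyword and its HNSW-only forwarding are
introduced by this patch.

    # Hypothetical example; adjust the import path and index path to your setup.
    from leann.api import LeannChat, LeannSearcher

    # Assumed: an index previously built with the HNSW backend at this placeholder path.
    searcher = LeannSearcher("indexes/demo.leann")

    # batch_size is forwarded to the backend only when the backend is "hnsw";
    # for other backends the argument is accepted but not passed through.
    results = searcher.search("what is LEANN?", top_k=5, batch_size=32)

    # LeannChat.ask() forwards batch_size straight through to LeannSearcher.search().
    chat = LeannChat("indexes/demo.leann")
    answer = chat.ask("what is LEANN?", top_k=5, batch_size=32)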