diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index d5f3a53..6f4c536 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -19,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    port: int = 5557,
+    port: Optional[int] = None,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -38,6 +38,8 @@ def compute_embeddings(
     """
     if use_server:
         # Use embedding server (for search/query)
+        if port is None:
+            raise ValueError("port is required when use_server is True")
         return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
         # Use direct computation (for build_index)
@@ -105,21 +107,19 @@ class PassageManager:
         self.global_offset_map = {}  # Combined map for fast lookup

         for source in passage_sources:
-            if source["type"] == "jsonl":
-                passage_file = source["path"]
-                index_file = source["index_path"]
-                if not Path(index_file).exists():
-                    raise FileNotFoundError(
-                        f"Passage index file not found: {index_file}"
-                    )
-                with open(index_file, "rb") as f:
-                    offset_map = pickle.load(f)
-                self.offset_maps[passage_file] = offset_map
-                self.passage_files[passage_file] = passage_file
+            assert source["type"] == "jsonl", "only jsonl is supported"
+            passage_file = source["path"]
+            index_file = source["index_path"]
+            if not Path(index_file).exists():
+                raise FileNotFoundError(f"Passage index file not found: {index_file}")
+            with open(index_file, "rb") as f:
+                offset_map = pickle.load(f)
+            self.offset_maps[passage_file] = offset_map
+            self.passage_files[passage_file] = passage_file

-                # Build global map for O(1) lookup
-                for passage_id, offset in offset_map.items():
-                    self.global_offset_map[passage_id] = (passage_file, offset)
+            # Build global map for O(1) lookup
+            for passage_id, offset in offset_map.items():
+                self.global_offset_map[passage_id] = (passage_file, offset)

     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
@@ -209,7 +209,6 @@ class LeannBuilder:
             self.embedding_model,
             self.embedding_mode,
             use_server=False,
-            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -382,9 +381,6 @@ class LeannSearcher:
         self.embedding_mode = self.meta_data.get(
             "embedding_mode", "sentence-transformers"
         )
-        # Backward compatibility with use_mlx
-        if self.meta_data.get("use_mlx", False):
-            self.embedding_mode = "mlx"
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -402,7 +398,7 @@ class LeannSearcher:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -416,7 +412,11 @@ class LeannSearcher:
         start_time = time.time()

-        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
+        query_embedding = self.backend_impl.compute_query_embedding(
+            query,
+            expected_zmq_port,
+            use_server_if_available=recompute_embeddings,
+        )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
         print(f" Embedding time: {embedding_time} seconds")

@@ -430,7 +430,7 @@
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -487,7 +487,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **search_kwargs,
     ):
@@ -502,7 +502,7 @@ class LeannChat:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **search_kwargs,
         )
         context = "\n\n".join([r.text for r in results])
diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py
index fc7c71b..dfa4cc1 100644
--- a/packages/leann-core/src/leann/interface.py
+++ b/packages/leann-core/src/leann/interface.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 import numpy as np
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional


 class LeannBackendBuilderInterface(ABC):
@@ -44,7 +44,7 @@
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +57,7 @@
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters

         Returns:
@@ -67,13 +67,16 @@

     @abstractmethod
     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: Optional[int] = None,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """Compute embedding for a query string

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py
index dfa6c2d..fad5589 100644
--- a/packages/leann-core/src/leann/searcher_base.py
+++ b/packages/leann-core/src/leann/searcher_base.py
@@ -2,7 +2,7 @@
 import json
 import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, Any, Literal
+from typing import Dict, Any, Literal, Optional

 import numpy as np
@@ -86,14 +86,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         return actual_port

     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: int = 5557,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
@@ -107,7 +110,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 self.index_dir / f"{self.index_path.name}.meta.json"
             )
             zmq_port = self._ensure_server_running(
-                str(passages_source_file), zmq_port
+                str(passages_source_file), expected_zmq_port
             )

             return self._compute_embedding_via_server([query], zmq_port)[
@@ -118,7 +121,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 print("⏭️ Falling back to direct model loading...")

         # Fallback to direct computation
-        from .api import compute_embeddings
+        from .embedding_compute import compute_embeddings

         embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
         return compute_embeddings([query], self.embedding_model, embedding_mode)
@@ -165,7 +168,7 @@
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -179,7 +182,7 @@
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

         Returns:
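
Note: a minimal sketch of the new compute_embeddings contract from the first two hunks. The model name and port values are illustrative placeholders, not taken from this diff:

    from leann.api import compute_embeddings

    chunks = ["first passage", "second passage"]
    model = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical model name

    # Build path: direct computation, no port argument needed anymore.
    embeddings = compute_embeddings(chunks, model, use_server=False)

    # Search path: the caller must now pass an explicit port.
    embeddings = compute_embeddings(chunks, model, use_server=True, port=5557)

    # Omitting the port on the server path now fails fast instead of
    # silently assuming a hardcoded default.
    try:
        compute_embeddings(chunks, model, use_server=True)  # port defaults to None
    except ValueError as exc:
        print(exc)  # "port is required when use_server is True"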
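Note: the PassageManager hunk keeps one offset map per passage file but flattens them all into a single dict, so get_passage is one hash probe rather than a scan over files. A self-contained sketch of that lookup structure, with made-up data:

    offset_maps = {
        "docs_a.jsonl": {"p1": 0, "p2": 1024},  # passage_id -> byte offset
        "docs_b.jsonl": {"p3": 0},
    }

    global_offset_map = {}
    for passage_file, offset_map in offset_maps.items():
        for passage_id, offset in offset_map.items():
            global_offset_map[passage_id] = (passage_file, offset)

    # O(1): which file holds the passage, and where to seek() within it.
    print(global_offset_map["p2"])  # ("docs_a.jsonl", 1024)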
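Note: the caller-side effect of the zmq_port -> expected_zmq_port rename in LeannSearcher.search. The index path and constructor shape here are assumptions for illustration, not confirmed by this diff:

    from leann.api import LeannSearcher

    searcher = LeannSearcher("indexes/my_docs.leann")  # hypothetical index path

    # Before: search(..., zmq_port=5557). After: the port is a hint;
    # leaving it as None lets the backend locate the embedding server itself.
    results = searcher.search(
        "how are passages located on disk?",
        recompute_embeddings=True,
        expected_zmq_port=None,
    )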