fix: use embedding server for query embedding only when recomputing

Author: Andy Lee
Date: 2025-07-21 20:40:21 -07:00
parent 1b6272ce0e
commit 2f224f5793
3 changed files with 42 additions and 36 deletions

View File

@@ -19,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    port: int = 5557,
+    port: Optional[int] = None,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -38,6 +38,8 @@ def compute_embeddings(
     """
     if use_server:
         # Use embedding server (for search/query)
+        if port is None:
+            raise ValueError("port is required when use_server is True")
         return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
         # Use direct computation (for build_index)
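The guard above changes the calling contract: with `use_server=True`, callers must now pass the server port explicitly instead of silently inheriting the old 5557 default. A minimal sketch of the new behavior, with a stub in place of the real backends (the model name and returned vectors are placeholders, not part of this commit):

```python
from typing import List, Optional
import numpy as np

def compute_embeddings_sketch(
    chunks: List[str],
    model_name: str,
    use_server: bool = True,
    port: Optional[int] = None,
) -> np.ndarray:
    """Mirrors the guard added above; real backends are stubbed out."""
    if use_server:
        if port is None:
            raise ValueError("port is required when use_server is True")
        # Real code: compute_embeddings_via_server(chunks, model_name, port=port)
    # Real code: direct computation (the build_index path).
    return np.zeros((len(chunks), 3), dtype=np.float32)

# The query path now fails fast when the caller forgets the port:
try:
    compute_embeddings_sketch(["hello"], "some-model", use_server=True)
except ValueError as err:
    print(err)  # port is required when use_server is True
```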
@@ -105,21 +107,19 @@ class PassageManager:
         self.global_offset_map = {}  # Combined map for fast lookup

         for source in passage_sources:
-            if source["type"] == "jsonl":
-                passage_file = source["path"]
-                index_file = source["index_path"]
-                if not Path(index_file).exists():
-                    raise FileNotFoundError(
-                        f"Passage index file not found: {index_file}"
-                    )
-                with open(index_file, "rb") as f:
-                    offset_map = pickle.load(f)
-                self.offset_maps[passage_file] = offset_map
-                self.passage_files[passage_file] = passage_file
+            assert source["type"] == "jsonl", "only jsonl is supported"
+            passage_file = source["path"]
+            index_file = source["index_path"]
+            if not Path(index_file).exists():
+                raise FileNotFoundError(f"Passage index file not found: {index_file}")
+            with open(index_file, "rb") as f:
+                offset_map = pickle.load(f)
+            self.offset_maps[passage_file] = offset_map
+            self.passage_files[passage_file] = passage_file

-                # Build global map for O(1) lookup
-                for passage_id, offset in offset_map.items():
-                    self.global_offset_map[passage_id] = (passage_file, offset)
+            # Build global map for O(1) lookup
+            for passage_id, offset in offset_map.items():
+                self.global_offset_map[passage_id] = (passage_file, offset)

     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
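For context on why this loop yields O(1) lookups: each pickled `index_path` file maps passage ids to byte offsets in its JSONL file, and the merged `global_offset_map` lets `get_passage` seek straight to a record. A sketch under that assumption (the one-JSON-object-per-line layout is inferred from the jsonl source type, not shown in this hunk):

```python
import json
from typing import Any, Dict, Tuple

def get_passage_sketch(
    global_offset_map: Dict[str, Tuple[str, int]], passage_id: str
) -> Dict[str, Any]:
    """Resolve a passage id via (file, byte offset) without scanning."""
    passage_file, offset = global_offset_map[passage_id]
    with open(passage_file, "r", encoding="utf-8") as f:
        f.seek(offset)                   # jump to the stored byte offset
        return json.loads(f.readline())  # one JSON record per line (assumed)
```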
@@ -209,7 +209,6 @@ class LeannBuilder:
             self.embedding_model,
             self.embedding_mode,
             use_server=False,
-            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -382,9 +381,6 @@ class LeannSearcher:
         self.embedding_mode = self.meta_data.get(
             "embedding_mode", "sentence-transformers"
         )
-        # Backward compatibility with use_mlx
-        if self.meta_data.get("use_mlx", False):
-            self.embedding_mode = "mlx"
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -402,7 +398,7 @@ class LeannSearcher:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -416,7 +412,11 @@ class LeannSearcher:
         start_time = time.time()
-        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
+        query_embedding = self.backend_impl.compute_query_embedding(
+            query,
+            expected_zmq_port,
+            use_server_if_available=recompute_embeddings,
+        )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
         print(f" Embedding time: {embedding_time} seconds")
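This call site is the fix named in the commit title: the query embedding goes through the ZMQ embedding server only when `recompute_embeddings=True`; a PQ-only search no longer forces a server round trip. A runnable toy version of the routing decision (the return values are stand-ins for the two embedding paths, not the real API):

```python
from typing import Optional, Tuple

def route_query_embedding(
    recompute_embeddings: bool, expected_zmq_port: Optional[int] = None
) -> Tuple[str, Optional[int]]:
    """Illustrates which path the query embedding takes after this commit."""
    if recompute_embeddings:
        # Search will contact the embedding server anyway, so the query
        # embedding reuses it (started on expected_zmq_port if needed).
        return ("server", expected_zmq_port)
    # No recompute: embed the query locally; no server is required.
    return ("local", None)

assert route_query_embedding(True, 5557) == ("server", 5557)
assert route_query_embedding(False) == ("local", None)
```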
@@ -430,7 +430,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -487,7 +487,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **search_kwargs,
     ):
@@ -502,7 +502,7 @@ class LeannChat:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **search_kwargs,
         )
         context = "\n\n".join([r.text for r in results])

View File

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 import numpy as np
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional


 class LeannBackendBuilderInterface(ABC):
@@ -44,7 +44,7 @@ class LeannBackendSearcherInterface(ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +57,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters

         Returns:
@@ -67,13 +67,16 @@ class LeannBackendSearcherInterface(ABC):

     @abstractmethod
     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: Optional[int] = None,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """Compute embedding for a query string

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
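Any backend implementing the interface must adopt the widened signature. A hedged skeleton of a conforming implementation (the class name, vector size, and stubbed `search` are placeholders, not from this commit):

```python
from typing import Any, Dict, Optional
import numpy as np

class DummySearcher(LeannBackendSearcherInterface):
    def search(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        raise NotImplementedError  # out of scope for this sketch

    def compute_query_embedding(
        self,
        query: str,
        expected_zmq_port: Optional[int] = None,
        use_server_if_available: bool = True,
    ) -> np.ndarray:
        # A real backend would embed `query` here; the dimension is made up.
        return np.zeros(384, dtype=np.float32)
```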

View File

@@ -2,7 +2,7 @@ import json
 import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, Any, Literal
+from typing import Dict, Any, Literal, Optional

 import numpy as np
@@ -86,14 +86,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         return actual_port

     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: int = 5557,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
@@ -107,7 +110,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                     self.index_dir / f"{self.index_path.name}.meta.json"
                 )
                 zmq_port = self._ensure_server_running(
-                    str(passages_source_file), zmq_port
+                    str(passages_source_file), expected_zmq_port
                 )
                 return self._compute_embedding_via_server([query], zmq_port)[
@@ -118,7 +121,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 print("⏭️ Falling back to direct model loading...")

         # Fallback to direct computation
-        from .api import compute_embeddings
+        from .embedding_compute import compute_embeddings

         embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
         return compute_embeddings([query], self.embedding_model, embedding_mode)
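Putting the pieces of this file together, the method keeps its server-first shape: try the embedding server (now reached only when the caller allows it), and fall back to in-process computation via the relocated helper. A condensed paraphrase of the flow the hunks show, with logging and error details trimmed (the trailing `[0]` indexing is inferred from the batch-style server API):

```python
def compute_query_embedding(self, query, expected_zmq_port=5557,
                            use_server_if_available=True):
    if use_server_if_available:
        try:
            passages_source_file = (
                self.index_dir / f"{self.index_path.name}.meta.json"
            )
            zmq_port = self._ensure_server_running(
                str(passages_source_file), expected_zmq_port
            )
            # The server embeds a batch; take the single query's vector.
            return self._compute_embedding_via_server([query], zmq_port)[0]
        except Exception:
            print("⏭️ Falling back to direct model loading...")

    # Fallback to direct computation with the relocated helper.
    from .embedding_compute import compute_embeddings
    embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
    return compute_embeddings([query], self.embedding_model, embedding_mode)
```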
@@ -165,7 +168,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
""" """
@@ -179,7 +182,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0) prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional" pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication expected_zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.) **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns: Returns: