fix: use embedding server for query only when recomputing

Andy Lee
2025-07-21 20:40:21 -07:00
parent 1b6272ce0e
commit 2f224f5793
3 changed files with 42 additions and 36 deletions
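
In short: LeannSearcher.search() previously always embedded the query through the ZMQ embedding server on a hard-coded port (5557); after this change the server is used only when recompute_embeddings=True, and otherwise the query is embedded in-process. A hedged sketch of the two call patterns (module path and index path are illustrative, not taken from this diff):

from leann.api import LeannSearcher  # module path assumed

searcher = LeannSearcher("./demo.leann")  # hypothetical index path

# Recompute path: search contacts the embedding server for fresh passage
# embeddings anyway, so the query embedding is routed through it too.
results = searcher.search("what changed?", recompute_embeddings=True, expected_zmq_port=5557)

# Default path: no ZMQ round-trip for the query; the embedding model
# is loaded and run directly in-process instead.
results = searcher.search("what changed?")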

View File

@@ -19,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    port: int = 5557,
+    port: Optional[int] = None,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -38,6 +38,8 @@ def compute_embeddings(
     """
     if use_server:
         # Use embedding server (for search/query)
+        if port is None:
+            raise ValueError("port is required when use_server is True")
         return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
         # Use direct computation (for build_index)
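
Since port no longer defaults to 5557, callers on the server path must be explicit, and a missing port now fails fast. A sketch of the resulting behavior (module path and model name assumed for illustration):

from leann.embedding_compute import compute_embeddings  # module path assumed

chunks = ["first passage", "second passage"]

# Server path: the port must now be supplied explicitly.
emb = compute_embeddings(chunks, "facebook/contriever", use_server=True, port=5557)

# Forgetting the port raises immediately instead of silently assuming 5557.
try:
    compute_embeddings(chunks, "facebook/contriever", use_server=True)
except ValueError as err:
    print(err)  # "port is required when use_server is True"

# Direct path, as used by build_index: no server, no port.
emb = compute_embeddings(chunks, "facebook/contriever", use_server=False)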
@@ -105,21 +107,19 @@ class PassageManager:
         self.global_offset_map = {}  # Combined map for fast lookup

         for source in passage_sources:
-            if source["type"] == "jsonl":
-                passage_file = source["path"]
-                index_file = source["index_path"]
-                if not Path(index_file).exists():
-                    raise FileNotFoundError(
-                        f"Passage index file not found: {index_file}"
-                    )
-                with open(index_file, "rb") as f:
-                    offset_map = pickle.load(f)
-                self.offset_maps[passage_file] = offset_map
-                self.passage_files[passage_file] = passage_file
+            assert source["type"] == "jsonl", "only jsonl is supported"
+            passage_file = source["path"]
+            index_file = source["index_path"]
+            if not Path(index_file).exists():
+                raise FileNotFoundError(f"Passage index file not found: {index_file}")
+            with open(index_file, "rb") as f:
+                offset_map = pickle.load(f)
+            self.offset_maps[passage_file] = offset_map
+            self.passage_files[passage_file] = passage_file

-                # Build global map for O(1) lookup
-                for passage_id, offset in offset_map.items():
-                    self.global_offset_map[passage_id] = (passage_file, offset)
+            # Build global map for O(1) lookup
+            for passage_id, offset in offset_map.items():
+                self.global_offset_map[passage_id] = (passage_file, offset)

     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
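
The pickled offset map is what makes the global lookup O(1): each passage id maps to a byte offset into its JSONL file, so fetching a passage is a seek plus one line read, never a scan. A minimal standalone sketch of that mechanism (file names hypothetical):

import json
import pickle

# offset_map: {passage_id: byte_offset_in_jsonl}, written at build time
with open("passages.idx", "rb") as f:  # hypothetical index file
    offset_map = pickle.load(f)

def load_passage(passage_id: str, passage_file: str = "passages.jsonl"):
    offset = offset_map[passage_id]      # O(1) dict lookup
    with open(passage_file, "rb") as f:
        f.seek(offset)                   # jump straight to the record
        return json.loads(f.readline())  # read exactly one JSON line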
@@ -209,7 +209,6 @@ class LeannBuilder:
             self.embedding_model,
             self.embedding_mode,
             use_server=False,
-            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -382,9 +381,6 @@ class LeannSearcher:
         self.embedding_mode = self.meta_data.get(
             "embedding_mode", "sentence-transformers"
         )
-        # Backward compatibility with use_mlx
-        if self.meta_data.get("use_mlx", False):
-            self.embedding_mode = "mlx"
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -402,7 +398,7 @@ class LeannSearcher:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -416,7 +412,11 @@ class LeannSearcher:
         start_time = time.time()
-        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
+        query_embedding = self.backend_impl.compute_query_embedding(
+            query,
+            expected_zmq_port,
+            use_server_if_available=recompute_embeddings,
+        )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
         print(f" Embedding time: {embedding_time} seconds")
@@ -430,7 +430,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -487,7 +487,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **search_kwargs,
     ):
@@ -502,7 +502,7 @@ class LeannChat:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **search_kwargs,
         )
         context = "\n\n".join([r.text for r in results])

View File

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 import numpy as np
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional


 class LeannBackendBuilderInterface(ABC):
@@ -44,7 +44,7 @@ class LeannBackendSearcherInterface(ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +57,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters

         Returns:
@@ -67,13 +67,16 @@ class LeannBackendSearcherInterface(ABC):
     @abstractmethod
     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: Optional[int] = None,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """Compute embedding for a query string

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
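
For backend implementers, the abstract signature now threads an optional expected port through instead of a baked-in 5557. A skeletal conforming override (class name and embedding helpers hypothetical; only the signature comes from this diff):

from typing import Optional
import numpy as np

class MySearcher:  # would subclass LeannBackendSearcherInterface in real code
    def compute_query_embedding(
        self,
        query: str,
        expected_zmq_port: Optional[int] = None,
        use_server_if_available: bool = True,
    ) -> np.ndarray:
        if use_server_if_available:
            # Resolve the actual port; the expectation may be None or stale.
            port = self._ensure_server_running(expected_zmq_port)  # hypothetical resolver
            return self._embed_via_server([query], port)[0]        # hypothetical helper
        return self._embed_locally([query])[0]                     # hypothetical helper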

View File

@@ -2,7 +2,7 @@ import json
 import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, Any, Literal
+from typing import Dict, Any, Literal, Optional

 import numpy as np
@@ -86,14 +86,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         return actual_port

     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: int = 5557,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
@@ -107,7 +110,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 self.index_dir / f"{self.index_path.name}.meta.json"
             )
             zmq_port = self._ensure_server_running(
-                str(passages_source_file), zmq_port
+                str(passages_source_file), expected_zmq_port
             )
             return self._compute_embedding_via_server([query], zmq_port)[
@@ -118,7 +121,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             print("⏭️ Falling back to direct model loading...")
             # Fallback to direct computation
-            from .api import compute_embeddings
+            from .embedding_compute import compute_embeddings

             embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
             return compute_embeddings([query], self.embedding_model, embedding_mode)
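
Per the surrounding context lines, this import sits in a fallback branch: the server path is tried first, and on failure the query is embedded directly, with compute_embeddings now imported from embedding_compute rather than api. The overall shape, sketched (absolute module path and searcher attributes assumed; the fallback call mirrors the diff):

def embed_query_with_fallback(searcher, query: str, expected_zmq_port: int):
    """Illustrative shape of the server-first, local-fallback path."""
    try:
        # Try the embedding server first, resolving the real port.
        port = searcher._ensure_server_running(searcher.passages_file, expected_zmq_port)  # names assumed
        return searcher._compute_embedding_via_server([query], port)[0]
    except Exception:
        print("⏭️ Falling back to direct model loading...")
        from leann.embedding_compute import compute_embeddings  # new import target per this commit
        mode = searcher.meta.get("embedding_mode", "sentence-transformers")
        return compute_embeddings([query], searcher.embedding_model, mode)[0]  # shape handling simplified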
@@ -165,7 +168,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -179,7 +182,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

         Returns:
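
Read together, the rename marks a semantic shift: expected_zmq_port is the port the caller expects a server on (possibly None), and the actual port is resolved only when recompute_embeddings makes the server necessary. A hedged usage sketch against the renamed parameters (searcher construction elided; query and values illustrative):

results = searcher.search(
    "how does pruning work?",         # query string, per the LeannSearcher entry point
    recompute_embeddings=True,        # fetch fresh embeddings via the ZMQ server
    expected_zmq_port=5557,           # where we expect that server, or None
    prune_ratio=0.5,                  # prune half the neighbors by approximate distance
    pruning_strategy="proportional",  # PQ candidate selection strategy
)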