fix: use server to embed query only when recomputing

Andy Lee
2025-07-21 20:40:21 -07:00
parent 1b6272ce0e
commit 2f224f5793
3 changed files with 42 additions and 36 deletions


@@ -19,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    port: int = 5557,
+    port: Optional[int] = None,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -38,6 +38,8 @@ def compute_embeddings(
     """
     if use_server:
         # Use embedding server (for search/query)
+        if port is None:
+            raise ValueError("port is required when use_server is True")
         return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
         # Use direct computation (for build_index)
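Note on the hunk above: making port mandatory for the server path removes the silently-shared default of 5557. A minimal usage sketch of both call patterns, assuming compute_embeddings is imported from this module; the model name and chunks are placeholders, not values from this repo:

    # Search/query path: server in use, so the port must now be explicit.
    query_vecs = compute_embeddings(
        ["what is leann?"],
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
        use_server=True,
        port=5557,  # omitting this now raises ValueError instead of assuming 5557
    )

    # Build path: direct computation, no server and no port involved.
    doc_vecs = compute_embeddings(
        chunks,  # placeholder list of chunk texts from the builder
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        use_server=False,
    )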
@@ -105,21 +107,19 @@ class PassageManager:
         self.global_offset_map = {}  # Combined map for fast lookup
         for source in passage_sources:
-            if source["type"] == "jsonl":
-                passage_file = source["path"]
-                index_file = source["index_path"]
-                if not Path(index_file).exists():
-                    raise FileNotFoundError(
-                        f"Passage index file not found: {index_file}"
-                    )
-                with open(index_file, "rb") as f:
-                    offset_map = pickle.load(f)
-                self.offset_maps[passage_file] = offset_map
-                self.passage_files[passage_file] = passage_file
+            assert source["type"] == "jsonl", "only jsonl is supported"
+            passage_file = source["path"]
+            index_file = source["index_path"]
+            if not Path(index_file).exists():
+                raise FileNotFoundError(f"Passage index file not found: {index_file}")
+            with open(index_file, "rb") as f:
+                offset_map = pickle.load(f)
+            self.offset_maps[passage_file] = offset_map
+            self.passage_files[passage_file] = passage_file
-                # Build global map for O(1) lookup
-                for passage_id, offset in offset_map.items():
-                    self.global_offset_map[passage_id] = (passage_file, offset)
+            # Build global map for O(1) lookup
+            for passage_id, offset in offset_map.items():
+                self.global_offset_map[passage_id] = (passage_file, offset)

     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
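For orientation, get_passage (context lines above) resolves ids through the combined map built in this hunk: one dict hit yields a (passage_file, offset) pair, then a seek into the JSONL file. A minimal sketch of that lookup, assuming one JSON object per line; this is illustrative, not the method body from the file:

    import json

    def get_passage_sketch(self, passage_id: str):
        # O(1): the global map already knows which file and byte offset to use.
        passage_file, offset = self.global_offset_map[passage_id]
        with open(passage_file, "r", encoding="utf-8") as f:
            f.seek(offset)                   # jump straight to the record
            return json.loads(f.readline())  # one JSONL record per line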
@@ -209,7 +209,6 @@ class LeannBuilder:
             self.embedding_model,
             self.embedding_mode,
             use_server=False,
-            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -382,9 +381,6 @@ class LeannSearcher:
         self.embedding_mode = self.meta_data.get(
             "embedding_mode", "sentence-transformers"
         )
-        # Backward compatibility with use_mlx
-        if self.meta_data.get("use_mlx", False):
-            self.embedding_mode = "mlx"
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -402,7 +398,7 @@
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -416,7 +412,11 @@
         start_time = time.time()
-        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
+        query_embedding = self.backend_impl.compute_query_embedding(
+            query,
+            expected_zmq_port,
+            use_server_if_available=recompute_embeddings,
+        )
         print(f"  Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
         print(f"  Embedding time: {embedding_time} seconds")
@@ -430,7 +430,7 @@
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -487,7 +487,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **search_kwargs,
     ):
@@ -502,7 +502,7 @@
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **search_kwargs,
         )
         context = "\n\n".join([r.text for r in results])
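End-to-end, the renamed keyword reads as follows from the caller's side. A usage sketch, assuming a LeannSearcher constructed from an index path and a top_k parameter; both are placeholders not shown in this diff:

    searcher = LeannSearcher("my_index.leann")  # placeholder index path

    # Default path: no recomputation, so no server and no port are needed.
    results = searcher.search("what is leann?", top_k=5)

    # Recompute path: the server embeds the query too, on the expected port.
    results = searcher.search(
        "what is leann?",
        top_k=5,
        recompute_embeddings=True,
        expected_zmq_port=5557,
    )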