fix: use server to emb query only when recompute
This commit is contained in:
@@ -19,7 +19,7 @@ def compute_embeddings(
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
port: int = 5557,
|
||||
port: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -38,6 +38,8 @@ def compute_embeddings(
|
||||
"""
|
||||
if use_server:
|
||||
# Use embedding server (for search/query)
|
||||
if port is None:
|
||||
raise ValueError("port is required when use_server is True")
|
||||
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||
else:
|
||||
# Use direct computation (for build_index)
|
||||
@@ -105,21 +107,19 @@ class PassageManager:
|
||||
self.global_offset_map = {} # Combined map for fast lookup
|
||||
|
||||
for source in passage_sources:
|
||||
if source["type"] == "jsonl":
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"]
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Passage index file not found: {index_file}"
|
||||
)
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"]
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
|
||||
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
||||
if passage_id in self.global_offset_map:
|
||||
@@ -209,7 +209,6 @@ class LeannBuilder:
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
port=5557,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
@@ -382,9 +381,6 @@ class LeannSearcher:
|
||||
self.embedding_mode = self.meta_data.get(
|
||||
"embedding_mode", "sentence-transformers"
|
||||
)
|
||||
# Backward compatibility with use_mlx
|
||||
if self.meta_data.get("use_mlx", False):
|
||||
self.embedding_mode = "mlx"
|
||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
@@ -402,7 +398,7 @@ class LeannSearcher:
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
expected_zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> List[SearchResult]:
|
||||
print("🔍 DEBUG LeannSearcher.search() called:")
|
||||
@@ -416,7 +412,11 @@ class LeannSearcher:
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
expected_zmq_port,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
)
|
||||
print(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
print(f" Embedding time: {embedding_time} seconds")
|
||||
@@ -430,7 +430,7 @@ class LeannSearcher:
|
||||
prune_ratio=prune_ratio,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
pruning_strategy=pruning_strategy,
|
||||
zmq_port=zmq_port,
|
||||
expected_zmq_port=expected_zmq_port,
|
||||
**kwargs,
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
@@ -487,7 +487,7 @@ class LeannChat:
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
expected_zmq_port: Optional[int] = None,
|
||||
llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**search_kwargs,
|
||||
):
|
||||
@@ -502,7 +502,7 @@ class LeannChat:
|
||||
prune_ratio=prune_ratio,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
pruning_strategy=pruning_strategy,
|
||||
zmq_port=zmq_port,
|
||||
expected_zmq_port=expected_zmq_port,
|
||||
**search_kwargs,
|
||||
)
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from typing import Dict, Any, List, Literal
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
|
||||
|
||||
class LeannBackendBuilderInterface(ABC):
|
||||
@@ -44,7 +44,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
expected_zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for nearest neighbors
|
||||
@@ -57,7 +57,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
expected_zmq_port: ZMQ port for embedding server communication
|
||||
**kwargs: Backend-specific parameters
|
||||
|
||||
Returns:
|
||||
@@ -67,13 +67,16 @@ class LeannBackendSearcherInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
self,
|
||||
query: str,
|
||||
expected_zmq_port: Optional[int] = None,
|
||||
use_server_if_available: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
Args:
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
expected_zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Literal
|
||||
from typing import Dict, Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -86,14 +86,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
return actual_port
|
||||
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
self,
|
||||
query: str,
|
||||
expected_zmq_port: int = 5557,
|
||||
use_server_if_available: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
|
||||
Args:
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
expected_zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
|
||||
Returns:
|
||||
@@ -107,7 +110,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
)
|
||||
zmq_port = self._ensure_server_running(
|
||||
str(passages_source_file), zmq_port
|
||||
str(passages_source_file), expected_zmq_port
|
||||
)
|
||||
|
||||
return self._compute_embedding_via_server([query], zmq_port)[
|
||||
@@ -118,7 +121,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("⏭️ Falling back to direct model loading...")
|
||||
|
||||
# Fallback to direct computation
|
||||
from .api import compute_embeddings
|
||||
from .embedding_compute import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
@@ -165,7 +168,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
expected_zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -179,7 +182,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
expected_zmq_port: ZMQ port for embedding server communication
|
||||
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
|
||||
|
||||
Returns:
|
||||
|
||||
Reference in New Issue
Block a user