fix: use server to emb query only when recompute
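
The change routes query embedding through the ZMQ embedding server only when
the caller asks for recomputed embeddings; otherwise the query is embedded
in-process. Along the way, the hard-coded zmq_port: int = 5557 parameters
become expected_zmq_port: Optional[int] = None, resolved at runtime. A rough
usage sketch of the resulting search API (the import path, constructor
arguments, and top_k are assumptions, not shown in this diff):

    from leann.api import LeannSearcher  # hypothetical import path

    searcher = LeannSearcher("./my_index")  # constructor args assumed

    # Default: recompute_embeddings=False, so the query embedding is computed
    # in-process and no embedding server or ZMQ port is involved.
    results = searcher.search("what is leann?", top_k=5)

    # Recompute path: the backend ensures the embedding server is running and
    # embeds the query over ZMQ; expected_zmq_port may also be left as None.
    results = searcher.search(
        "what is leann?",
        top_k=5,
        recompute_embeddings=True,
        expected_zmq_port=5557,
    )
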
@@ -19,7 +19,7 @@ def compute_embeddings(
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    port: int = 5557,
+    port: Optional[int] = None,
 ) -> np.ndarray:
     """
     Computes embeddings using different backends.
@@ -38,6 +38,8 @@ def compute_embeddings(
     """
     if use_server:
         # Use embedding server (for search/query)
+        if port is None:
+            raise ValueError("port is required when use_server is True")
         return compute_embeddings_via_server(chunks, model_name, port=port)
     else:
         # Use direct computation (for build_index)
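
Note: with port now Optional, a server-backed call must supply one
explicitly. A quick sketch of the guard's behavior (the chunk list and model
name are placeholders):

    chunks = ["some passage text"]

    # Direct computation (build path): no port needed.
    compute_embeddings(chunks, "my-model", use_server=False)

    # Server-backed computation (search path): the port is now mandatory.
    compute_embeddings(chunks, "my-model", use_server=True, port=5557)

    # Forgetting the port fails fast instead of silently assuming 5557.
    compute_embeddings(chunks, "my-model", use_server=True)  # ValueError
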
@@ -105,21 +107,19 @@ class PassageManager:
         self.global_offset_map = {}  # Combined map for fast lookup

         for source in passage_sources:
-            if source["type"] == "jsonl":
-                passage_file = source["path"]
-                index_file = source["index_path"]
-                if not Path(index_file).exists():
-                    raise FileNotFoundError(
-                        f"Passage index file not found: {index_file}"
-                    )
-                with open(index_file, "rb") as f:
-                    offset_map = pickle.load(f)
-                self.offset_maps[passage_file] = offset_map
-                self.passage_files[passage_file] = passage_file
+            assert source["type"] == "jsonl", "only jsonl is supported"
+            passage_file = source["path"]
+            index_file = source["index_path"]
+            if not Path(index_file).exists():
+                raise FileNotFoundError(f"Passage index file not found: {index_file}")
+            with open(index_file, "rb") as f:
+                offset_map = pickle.load(f)
+            self.offset_maps[passage_file] = offset_map
+            self.passage_files[passage_file] = passage_file

-                # Build global map for O(1) lookup
-                for passage_id, offset in offset_map.items():
-                    self.global_offset_map[passage_id] = (passage_file, offset)
+            # Build global map for O(1) lookup
+            for passage_id, offset in offset_map.items():
+                self.global_offset_map[passage_id] = (passage_file, offset)

     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
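
For context, global_offset_map maps each passage id to a (file, byte offset)
pair so get_passage can seek straight to the record. A sketch of how such a
lookup plausibly proceeds for JSONL passages (the actual get_passage body is
not shown beyond its first line in this hunk):

    def get_passage(self, passage_id: str) -> Dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
            with open(passage_file, "r", encoding="utf-8") as f:
                f.seek(offset)                   # jump to the record's offset
                return json.loads(f.readline())  # one JSON object per line
        raise KeyError(f"Passage not found: {passage_id}")
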
@@ -209,7 +209,6 @@ class LeannBuilder:
             self.embedding_model,
             self.embedding_mode,
             use_server=False,
-            port=5557,
         )
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -382,9 +381,6 @@ class LeannSearcher:
         self.embedding_mode = self.meta_data.get(
             "embedding_mode", "sentence-transformers"
         )
-        # Backward compatibility with use_mlx
-        if self.meta_data.get("use_mlx", False):
-            self.embedding_mode = "mlx"
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -402,7 +398,7 @@ class LeannSearcher:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -416,7 +412,11 @@ class LeannSearcher:

         start_time = time.time()

-        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
+        query_embedding = self.backend_impl.compute_query_embedding(
+            query,
+            expected_zmq_port,
+            use_server_if_available=recompute_embeddings,
+        )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
         print(f" Embedding time: {embedding_time} seconds")
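
This call is the heart of the fix: use_server_if_available is tied to
recompute_embeddings, so the embedding server is contacted only when
recomputation is requested. Paraphrasing the control flow that the
BaseSearcher hunks further down implement (a sketch, not a verbatim copy):

    if use_server_if_available:
        # Recompute requested: ensure the server is up, resolve the actual
        # port, and embed the query over ZMQ.
        port = self._ensure_server_running(
            str(passages_source_file), expected_zmq_port
        )
        return self._compute_embedding_via_server([query], port)[0]
    # No recompute: load the model directly and embed in-process.
    return compute_embeddings([query], self.embedding_model, embedding_mode)
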
@@ -430,7 +430,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -487,7 +487,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **search_kwargs,
     ):
@@ -502,7 +502,7 @@ class LeannChat:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            zmq_port=zmq_port,
+            expected_zmq_port=expected_zmq_port,
             **search_kwargs,
         )
         context = "\n\n".join([r.text for r in results])

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 import numpy as np
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional


 class LeannBackendBuilderInterface(ABC):
@@ -44,7 +44,7 @@ class LeannBackendSearcherInterface(ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +57,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters

         Returns:
@@ -67,13 +67,16 @@ class LeannBackendSearcherInterface(ABC):

     @abstractmethod
     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: Optional[int] = None,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """Compute embedding for a query string

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
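
Any backend implementing LeannBackendSearcherInterface must adopt the widened
signature. A minimal conforming stub (illustrative only; the other abstract
methods are omitted):

    class DummySearcher(LeannBackendSearcherInterface):
        def compute_query_embedding(
            self,
            query: str,
            expected_zmq_port: Optional[int] = None,
            use_server_if_available: bool = True,
        ) -> np.ndarray:
            # Toy stand-in: a fixed zero vector instead of a real model.
            return np.zeros(768, dtype=np.float32)
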

@@ -2,7 +2,7 @@ import json
 import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, Any, Literal
+from typing import Dict, Any, Literal, Optional

 import numpy as np

@@ -86,14 +86,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         return actual_port

     def compute_query_embedding(
-        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+        self,
+        query: str,
+        expected_zmq_port: int = 5557,
+        use_server_if_available: bool = True,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.

         Args:
             query: The query string to embed
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
@@ -107,7 +110,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 self.index_dir / f"{self.index_path.name}.meta.json"
             )
             zmq_port = self._ensure_server_running(
-                str(passages_source_file), zmq_port
+                str(passages_source_file), expected_zmq_port
             )

             return self._compute_embedding_via_server([query], zmq_port)[
@@ -118,7 +121,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             print("⏭️ Falling back to direct model loading...")

         # Fallback to direct computation
-        from .api import compute_embeddings
+        from .embedding_compute import compute_embeddings

         embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
         return compute_embeddings([query], self.embedding_model, embedding_mode)
@@ -165,7 +168,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -179,7 +182,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            expected_zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

         Returns: