fix: cache the loaded model
@@ -70,9 +70,7 @@ async def main(args):
         # )
 
         print(f"You: {query}")
-        chat_response = chat.ask(
-            query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
-        )
+        chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
         print(f"Leann: {chat_response}")
 
 
@@ -4,7 +4,6 @@ import struct
 from pathlib import Path
 from typing import Dict, Any, List, Literal
 import contextlib
-import pickle
 
 from leann.searcher_base import BaseSearcher
 from leann.registry import register_backend
@@ -70,7 +69,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)
 
-
         build_kwargs = {**self.build_params, **kwargs}
         metric_enum = _get_diskann_metrics().get(
             build_kwargs.get("distance_metric", "mips").lower()
@@ -207,8 +205,7 @@ class DiskannSearcher(BaseSearcher):
         )
 
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
 
         return {"labels": string_labels, "distances": distances}
@@ -1,10 +1,9 @@
 import numpy as np
 import os
 from pathlib import Path
-from typing import Dict, Any, List, Literal
-import pickle
+from typing import Dict, Any, List, Literal, Optional
 import shutil
-import time
+import logging
 
 from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -16,6 +15,8 @@ from leann.interface import (
     LeannBackendSearcherInterface,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def get_metric_map():
     from . import faiss  # type: ignore
@@ -57,9 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         index_dir.mkdir(parents=True, exist_ok=True)
 
         if data.dtype != np.float32:
+            logger.warning(f"Converting data to float32, shape: {data.shape}")
             data = data.astype(np.float32)
 
-
         metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -81,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
     def _convert_to_csr(self, index_file: Path):
         """Convert built index to CSR format"""
         mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
-        print(f"INFO: Converting HNSW index to {mode_str} format...")
+        logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
 
         csr_temp_file = index_file.with_suffix(".csr.tmp")
 
@@ -90,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         )
 
         if success:
-            print("✅ CSR conversion successful.")
+            logger.info("✅ CSR conversion successful.")
             index_file_old = index_file.with_suffix(".old")
             shutil.move(str(index_file), str(index_file_old))
             shutil.move(str(csr_temp_file), str(index_file))
-            print(
+            logger.info(
                 f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
             )
         else:
@@ -131,14 +132,12 @@ class HNSWSearcher(BaseSearcher):
 
         hnsw_config = faiss.HNSWIndexConfig()
         hnsw_config.is_compact = self.is_compact
-        hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
-        if self.is_pruned and not hnsw_config.is_recompute:
-            raise RuntimeError("Index is pruned but recompute is disabled.")
+        hnsw_config.is_recompute = (
+            self.is_pruned
+        )  # In C++ code, it's called is_recompute, but it's only for loading IIUC.
 
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
 
 
     def search(
         self,
         query: np.ndarray,
@@ -146,9 +145,9 @@ class HNSWSearcher(BaseSearcher):
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         batch_size: int = 0,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -166,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
                 - "global": Use global PQ queue size for selection (default)
                 - "local": Local pruning, sort and select best candidates
                 - "proportional": Base selection on new neighbor count ratio
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
             **kwargs: Additional HNSW-specific parameters (for legacy compatibility)
 
@@ -175,15 +174,9 @@ class HNSWSearcher(BaseSearcher):
         """
         from . import faiss  # type: ignore
 
-        # Use recompute_embeddings parameter
-        use_recompute = recompute_embeddings or self.is_pruned
-        if use_recompute:
-            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
-            if not meta_file_path.exists():
-                raise RuntimeError(
-                    f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
-                )
-            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
+        if not recompute_embeddings:
+            if self.is_pruned:
+                raise RuntimeError("Recompute is required for pruned index.")
 
         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -191,7 +184,7 @@ class HNSWSearcher(BaseSearcher):
             faiss.normalize_L2(query)
 
         params = faiss.SearchParametersHNSW()
-        params.zmq_port = zmq_port
+        params.zmq_port = expected_zmq_port
         params.efSearch = complexity
         params.beam_size = beam_width
 
@@ -228,8 +221,7 @@ class HNSWSearcher(BaseSearcher):
         )
 
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
 
         return {"labels": string_labels, "distances": distances}
@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
-    passages_data: Optional[Dict[str, str]] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
     """
-    # Auto-detect mode based on model name if not explicitly set
-    if embedding_mode == "sentence-transformers" and model_name.startswith(
-        "text-embedding-"
-    ):
-        embedding_mode = "openai"
-
     print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     print(f"Using embedding mode: {embedding_mode}")
 
@@ -64,6 +57,7 @@ def create_hnsw_embedding_server(
     finally:
         sys.path.pop(0)
 
+
     # Check port availability
     import socket
 
@@ -84,7 +78,9 @@ def create_hnsw_embedding_server(
         meta = json.load(f)
 
     passages = PassageManager(meta["passage_sources"])
-    print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
+    print(
+        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
+    )
 
     def zmq_server_thread():
         """ZMQ server thread"""
@@ -112,7 +108,7 @@ def create_hnsw_embedding_server(
                     f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
                 )
 
-                # Use unified embedding computation
+                # Use unified embedding computation (now with model caching)
                 embeddings = compute_embeddings(
                     request_payload, model_name, mode=embedding_mode
                 )
@@ -148,15 +144,15 @@ def create_hnsw_embedding_server(
                         texts.append(txt)
                     except KeyError:
                         print(f"ERROR: Passage ID {nid} not found")
-                        raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
+                        raise RuntimeError(
+                            f"FATAL: Passage with ID {nid} not found"
+                        )
                     except Exception as e:
                         print(f"ERROR: Exception looking up passage ID {nid}: {e}")
                         raise
 
                 # Process embeddings
-                embeddings = compute_embeddings(
-                    texts, model_name, mode=embedding_mode
-                )
+                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                 print(
                     f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                 )
@@ -204,7 +200,9 @@ def create_hnsw_embedding_server(
                         passage_data = passages.get_passage(str(nid))
                         txt = passage_data["text"]
                         if not txt:
-                            raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
+                            raise RuntimeError(
+                                f"FATAL: Empty text for passage ID {nid}"
+                            )
                         texts.append(txt)
                     except KeyError:
                         raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
@@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code.
 
 import json
 import pickle
+from leann.interface import LeannBackendSearcherInterface
 import numpy as np
+import time
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
@@ -126,6 +128,7 @@ class PassageManager:
     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
+            # Lazy file opening - only open when needed
             with open(passage_file, "r", encoding="utf-8") as f:
                 f.seek(offset)
                 return json.loads(f.readline())
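Note: the get_passage hunk above relies on an offset-indexed JSONL layout: each lookup seeks to a stored byte offset and reads exactly one line, so no file handles are kept open between calls. A minimal sketch of that pattern (the offset map below is a hypothetical stand-in for PassageManager.global_offset_map):

    import json

    # Hypothetical index: passage_id -> (jsonl_path, byte_offset of that record).
    offset_map = {"42": ("passages.jsonl", 1024)}

    def get_passage(passage_id: str) -> dict:
        passage_file, offset = offset_map[passage_id]
        # Lazy file opening - the file is only open for the duration of one lookup.
        with open(passage_file, "r", encoding="utf-8") as f:
            f.seek(offset)                      # jump straight to the record
            return json.loads(f.readline())     # one JSON object per line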
@@ -373,10 +376,12 @@ class LeannBuilder:
 
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
-        meta_path_str = f"{index_path}.meta.json"
-        if not Path(meta_path_str).exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
-        with open(meta_path_str, "r", encoding="utf-8") as f:
+        self.meta_path_str = f"{index_path}.meta.json"
+        if not Path(self.meta_path_str).exists():
+            raise FileNotFoundError(
+                f"Leann metadata file not found at {self.meta_path_str}"
+            )
+        with open(self.meta_path_str, "r", encoding="utf-8") as f:
             self.meta_data = json.load(f)
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
@@ -390,7 +395,9 @@ class LeannSearcher:
             raise ValueError(f"Backend '{backend_name}' not found.")
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
         final_kwargs["enable_warmup"] = enable_warmup
-        self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
+        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
+            index_path, **final_kwargs
+        )
 
     def search(
         self,
@@ -399,9 +406,9 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        expected_zmq_port: int = 5557,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -409,16 +416,21 @@ class LeannSearcher:
         print(f" Top_k: {top_k}")
         print(f" Additional kwargs: {kwargs}")
 
-        # Use backend's compute_query_embedding method
-        # This will automatically use embedding server if available and needed
-        import time
 
         start_time = time.time()
 
+        zmq_port = None
+        if recompute_embeddings:
+            zmq_port = self.backend_impl._ensure_server_running(
+                self.meta_path_str,
+                port=expected_zmq_port,
+                **kwargs,
+            )
+        del expected_zmq_port
+
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            expected_zmq_port,
             use_server_if_available=recompute_embeddings,
+            zmq_port=zmq_port,
         )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
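Note: the rewritten search flow above resolves the embedding-server port once and then reuses it everywhere downstream. A rough sketch of the intended sequence, using the names from the diff (not the literal method body):

    def _search_flow(self, query: str, top_k: int, recompute_embeddings: bool = True,
                     expected_zmq_port: int = 5557, **kwargs):
        # Start (or reuse) the embedding server only when recomputation is requested,
        # and keep the port it actually reports back.
        zmq_port = None
        if recompute_embeddings:
            zmq_port = self.backend_impl._ensure_server_running(
                self.meta_path_str, port=expected_zmq_port, **kwargs
            )

        # The resolved port, not the requested one, is handed to both calls below.
        query_embedding = self.backend_impl.compute_query_embedding(
            query, use_server_if_available=recompute_embeddings, zmq_port=zmq_port
        )
        return self.backend_impl.search(
            query_embedding, top_k,
            recompute_embeddings=recompute_embeddings,
            expected_zmq_port=zmq_port,
            **kwargs,
        )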
@@ -433,7 +445,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=expected_zmq_port,
+            expected_zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -488,10 +500,10 @@ class LeannChat:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
+        expected_zmq_port: int = 5557,
         **search_kwargs,
     ):
         if llm_kwargs is None:
@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
 
 import numpy as np
 import torch
-from typing import List
+from typing import List, Dict, Any, Optional
 import logging
 
 logger = logging.getLogger(__name__)
 
+# Global model cache to avoid repeated loading
+_model_cache: Dict[str, Any] = {}
+
 
 def compute_embeddings(
     texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False
@@ -45,25 +48,12 @@ def compute_embeddings_sentence_transformers(
     is_build: bool = False,
 ) -> np.ndarray:
     """
-    Compute embeddings using SentenceTransformer
-    Preserves all optimization parameters to ensure consistency with original embedding_server
-
-    Args:
-        texts: List of texts to compute embeddings for
-        model_name: SentenceTransformer model name
-        use_fp16: Whether to use FP16 precision
-        device: Device selection ('auto', 'cuda', 'mps', 'cpu')
-        batch_size: Batch size for processing
-
-    Returns:
-        Normalized embeddings array, shape: (len(texts), embedding_dim)
+    Compute embeddings using SentenceTransformer with model caching
     """
     print(
         f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
     )
 
-    from sentence_transformers import SentenceTransformer
-
     # Auto-detect device
     if device == "auto":
         if torch.cuda.is_available():
@@ -73,62 +63,72 @@ def compute_embeddings_sentence_transformers(
         else:
             device = "cpu"
 
-    print(f"INFO: Using device: {device}")
-
-    # Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
-    model_kwargs = {
-        "torch_dtype": torch.float16 if use_fp16 else torch.float32,
-        "low_cpu_mem_usage": True,
-        "_fast_init": True,  # Skip weight initialization checks for faster loading
-    }
-
-    tokenizer_kwargs = {
-        "use_fast": True,  # Use fast tokenizer for better runtime performance
-    }
-
-    # Load SentenceTransformer (try local first, then network)
-    print(f"INFO: Loading SentenceTransformer model: {model_name}")
-
-    try:
-        # Try local loading (avoid network delays)
-        model_kwargs["local_files_only"] = True
-        tokenizer_kwargs["local_files_only"] = True
-
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=True,
-        )
-        print("✅ Model loaded successfully! (local + optimized)")
-    except Exception as e:
-        print(f"Local loading failed ({e}), trying network download...")
-        # Fallback to network loading
-        model_kwargs["local_files_only"] = False
-        tokenizer_kwargs["local_files_only"] = False
-
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=False,
-        )
-        print("✅ Model loaded successfully! (network + optimized)")
-
-    # Apply additional optimizations (if supported)
-    if use_fp16 and device in ["cuda", "mps"]:
-        try:
-            model = model.half()
-            model = torch.compile(model)
-            print(f"✅ Using FP16 precision and compile optimization: {model_name}")
-        except Exception as e:
-            print(
-                f"FP16 or compile optimization failed, continuing with default settings: {e}"
-            )
-
-    # Compute embeddings (using SentenceTransformer's optimized implementation)
+    # Create cache key
+    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
+
+    # Check if model is already cached
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached model: {model_name}")
+        model = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching SentenceTransformer model: {model_name}")
+        from sentence_transformers import SentenceTransformer
+
+        print(f"INFO: Using device: {device}")
+
+        # Prepare model and tokenizer optimization parameters
+        model_kwargs = {
+            "torch_dtype": torch.float16 if use_fp16 else torch.float32,
+            "low_cpu_mem_usage": True,
+            "_fast_init": True,
+        }
+
+        tokenizer_kwargs = {
+            "use_fast": True,
+        }
+
+        try:
+            # Try local loading first
+            model_kwargs["local_files_only"] = True
+            tokenizer_kwargs["local_files_only"] = True
+
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                local_files_only=True,
+            )
+            print("✅ Model loaded successfully! (local + optimized)")
+        except Exception as e:
+            print(f"Local loading failed ({e}), trying network download...")
+            # Fallback to network loading
+            model_kwargs["local_files_only"] = False
+            tokenizer_kwargs["local_files_only"] = False
+
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                local_files_only=False,
+            )
+            print("✅ Model loaded successfully! (network + optimized)")
+
+        # Apply additional optimizations (if supported)
+        if use_fp16 and device in ["cuda", "mps"]:
+            try:
+                model = model.half()
+                model = torch.compile(model)
+                print(f"✅ Using FP16 precision and compile optimization: {model_name}")
+            except Exception as e:
+                print(f"FP16 or compile optimization failed: {e}")
+
+        # Cache the model
+        _model_cache[cache_key] = model
+        print(f"✅ Model cached: {cache_key}")
+
+    # Compute embeddings
     print("INFO: Starting embedding computation...")
 
     embeddings = model.encode(
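Note: the new loading path above is a plain module-level cache. The key folds in every option that changes the loaded object (model name, device, FP16), so a second call in the same process gets the already-initialized model back instead of reloading weights. The pattern in isolation, with load_model as a hypothetical stand-in for the SentenceTransformer/MLX/OpenAI loading branches:

    from typing import Any, Dict

    _model_cache: Dict[str, Any] = {}

    def get_or_load(model_name: str, device: str, use_fp16: bool) -> Any:
        # The key must capture everything that affects the loaded object.
        cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
        if cache_key in _model_cache:
            return _model_cache[cache_key]   # cache hit: reuse the loaded model
        model = load_model(model_name, device=device, fp16=use_fp16)  # hypothetical loader
        _model_cache[cache_key] = model      # cache miss: load once, keep for the process lifetime
        return model

Since the cache lives inside the embedding-server process, it pays off on every request after the first one.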
@@ -136,7 +136,7 @@ def compute_embeddings_sentence_transformers(
         batch_size=batch_size,
         show_progress_bar=is_build,  # Don't show progress bar in server environment
         convert_to_numpy=True,
-        normalize_embeddings=False,  # Keep consistent with original API behavior
+        normalize_embeddings=False,
         device=device,
     )
 
@@ -166,7 +166,14 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
 
-    client = openai.OpenAI(api_key=api_key)
+    # Cache OpenAI client
+    cache_key = "openai_client"
+    if cache_key in _model_cache:
+        client = _model_cache[cache_key]
+    else:
+        client = openai.OpenAI(api_key=api_key)
+        _model_cache[cache_key] = client
+        print("✅ OpenAI client cached")
 
     print(
         f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
@@ -214,7 +221,6 @@ def compute_embeddings_mlx(
     try:
         import mlx.core as mx
         from mlx_lm.utils import load
-        from tqdm import tqdm
     except ImportError as e:
         raise RuntimeError(
             "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
@@ -224,8 +230,16 @@ def compute_embeddings_mlx(
         f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
     )
 
-    # Load model and tokenizer
-    model, tokenizer = load(model_name)
+    # Cache MLX model and tokenizer
+    cache_key = f"mlx_{model_name}"
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached MLX model: {model_name}")
+        model, tokenizer = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching MLX model: {model_name}")
+        model, tokenizer = load(model_name)
+        _model_cache[cache_key] = (model, tokenizer)
+        print(f"✅ MLX model cached: {cache_key}")
 
     # Process chunks in batches with progress bar
     all_embeddings = []
@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
         """
         pass
 
+    @abstractmethod
+    def _ensure_server_running(
+        self, passages_source_file: str, port: Optional[int], **kwargs
+    ) -> int:
+        """Ensure server is running"""
+        pass
+
     @abstractmethod
     def search(
         self,
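Note: promoting _ensure_server_running to the abstract interface pins down the contract the core API now depends on: given the passages metadata file and a requested port, a backend must start (or find) its embedding server and return the port that is actually in use. A hedged sketch of a conforming implementation (illustrative only, not the real HNSW/DiskANN code):

    from typing import Optional

    class ExampleSearcher:
        """Illustrative backend; only shows the _ensure_server_running contract."""
        _server_port: Optional[int] = None

        def _ensure_server_running(
            self, passages_source_file: str, port: Optional[int], **kwargs
        ) -> int:
            # Idempotent: if a server was already started, just report its port.
            if self._server_port is None:
                chosen = port if port is not None else 5557
                # ... spawn the embedding server for passages_source_file on `chosen` ...
                self._server_port = chosen
            return self._server_port

Callers such as LeannSearcher.search then pass the returned port on to compute_query_embedding and the backend search instead of the port they originally asked for.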
@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            expected_zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters
 
         Returns:
@@ -69,14 +76,14 @@ class LeannBackendSearcherInterface(ABC):
     def compute_query_embedding(
         self,
         query: str,
-        expected_zmq_port: Optional[int] = None,
         use_server_if_available: bool = True,
+        zmq_port: Optional[int] = None,
     ) -> np.ndarray:
         """Compute embedding for a query string
 
         Args:
             query: The query string to embed
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first
 
         Returns:
@@ -1,5 +1,4 @@
 import json
-import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, Any, Literal, Optional
@@ -88,15 +87,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
     def compute_query_embedding(
         self,
         query: str,
-        expected_zmq_port: int = 5557,
         use_server_if_available: bool = True,
+        zmq_port: int = 5557,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.
 
         Args:
             query: The query string to embed
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first
 
         Returns:
@@ -110,7 +109,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 self.index_dir / f"{self.index_path.name}.meta.json"
             )
             zmq_port = self._ensure_server_running(
-                str(passages_source_file), expected_zmq_port
+                str(passages_source_file), zmq_port
             )
 
             return self._compute_embedding_via_server([query], zmq_port)[
@@ -168,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -182,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            expected_zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
 
         Returns: