fix: cache the loaded model

Andy Lee
2025-07-21 21:20:53 -07:00
parent 727724990e
commit b3970793cf
9 changed files with 163 additions and 146 deletions
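What changed: embedding models (SentenceTransformer, MLX, and the OpenAI client) used to be loaded on every call into compute_embeddings(); they are now kept in a module-level cache keyed by backend, model name, device, and precision, so the embedding server reuses the already-loaded model across requests. A minimal sketch of the pattern, assuming the SentenceTransformer backend (the helper name _get_cached_model is illustrative, not from this diff; the real code additionally handles FP16, torch.compile, and local-vs-network loading):

from typing import Any, Dict

from sentence_transformers import SentenceTransformer

_model_cache: Dict[str, Any] = {}

def _get_cached_model(model_name: str, device: str, use_fp16: bool) -> Any:
    # Load once per (model, device, precision); later calls reuse the cached instance.
    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
    if cache_key not in _model_cache:
        _model_cache[cache_key] = SentenceTransformer(model_name, device=device)
    return _model_cache[cache_key]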

View File

@@ -70,9 +70,7 @@ async def main(args):
         # )
         print(f"You: {query}")
-        chat_response = chat.ask(
-            query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
-        )
+        chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
         print(f"Leann: {chat_response}")

View File

@@ -4,7 +4,6 @@ import struct
 from pathlib import Path
 from typing import Dict, Any, List, Literal
 import contextlib
-import pickle
 from leann.searcher_base import BaseSearcher
 from leann.registry import register_backend
@@ -70,7 +69,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)
         build_kwargs = {**self.build_params, **kwargs}
         metric_enum = _get_diskann_metrics().get(
             build_kwargs.get("distance_metric", "mips").lower()
@@ -207,8 +205,7 @@ class DiskannSearcher(BaseSearcher):
         )
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
         return {"labels": string_labels, "distances": distances}

View File

@@ -1,10 +1,9 @@
 import numpy as np
 import os
 from pathlib import Path
-from typing import Dict, Any, List, Literal
-import pickle
+from typing import Dict, Any, List, Literal, Optional
 import shutil
-import time
+import logging
 from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -16,6 +15,8 @@ from leann.interface import (
     LeannBackendSearcherInterface,
 )

+logger = logging.getLogger(__name__)
+

 def get_metric_map():
     from . import faiss  # type: ignore
@@ -57,9 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         index_dir.mkdir(parents=True, exist_ok=True)

         if data.dtype != np.float32:
+            logger.warning(f"Converting data to float32, shape: {data.shape}")
             data = data.astype(np.float32)

         metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -81,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
     def _convert_to_csr(self, index_file: Path):
         """Convert built index to CSR format"""
         mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
-        print(f"INFO: Converting HNSW index to {mode_str} format...")
+        logger.info(f"INFO: Converting HNSW index to {mode_str} format...")

         csr_temp_file = index_file.with_suffix(".csr.tmp")
@@ -90,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         )
         if success:
-            print("✅ CSR conversion successful.")
+            logger.info("✅ CSR conversion successful.")
             index_file_old = index_file.with_suffix(".old")
             shutil.move(str(index_file), str(index_file_old))
             shutil.move(str(csr_temp_file), str(index_file))
-            print(
+            logger.info(
                 f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
             )
         else:
@@ -131,13 +132,11 @@ class HNSWSearcher(BaseSearcher):
         hnsw_config = faiss.HNSWIndexConfig()
         hnsw_config.is_compact = self.is_compact
-        hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
-        if self.is_pruned and not hnsw_config.is_recompute:
-            raise RuntimeError("Index is pruned but recompute is disabled.")
+        hnsw_config.is_recompute = (
+            self.is_pruned
+        )  # In C++ code, it's called is_recompute, but it's only for loading IIUC.
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)

     def search(
         self,
@@ -146,9 +145,9 @@ class HNSWSearcher(BaseSearcher):
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         batch_size: int = 0,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -166,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
                 - "global": Use global PQ queue size for selection (default)
                 - "local": Local pruning, sort and select best candidates
                 - "proportional": Base selection on new neighbor count ratio
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
             **kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -175,15 +174,9 @@ class HNSWSearcher(BaseSearcher):
         """
         from . import faiss  # type: ignore

-        # Use recompute_embeddings parameter
-        use_recompute = recompute_embeddings or self.is_pruned
-        if use_recompute:
-            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
-            if not meta_file_path.exists():
-                raise RuntimeError(
-                    f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
-                )
-            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
+        if not recompute_embeddings:
+            if self.is_pruned:
+                raise RuntimeError("Recompute is required for pruned index.")

         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -191,7 +184,7 @@ class HNSWSearcher(BaseSearcher):
             faiss.normalize_L2(query)

         params = faiss.SearchParametersHNSW()
-        params.zmq_port = zmq_port
+        params.zmq_port = expected_zmq_port
         params.efSearch = complexity
         params.beam_size = beam_width
@@ -228,8 +221,7 @@ class HNSWSearcher(BaseSearcher):
         )
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
         return {"labels": string_labels, "distances": distances}

View File

@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
-    passages_data: Optional[Dict[str, str]] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
     """
-    # Auto-detect mode based on model name if not explicitly set
-    if embedding_mode == "sentence-transformers" and model_name.startswith(
-        "text-embedding-"
-    ):
-        embedding_mode = "openai"
     print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     print(f"Using embedding mode: {embedding_mode}")
@@ -64,6 +57,7 @@ def create_hnsw_embedding_server(
     finally:
         sys.path.pop(0)
+
     # Check port availability
     import socket
@@ -78,13 +72,15 @@ def create_hnsw_embedding_server(
     # Only support metadata file, fail fast for everything else
     if not passages_file or not passages_file.endswith(".meta.json"):
         raise ValueError("Only metadata files (.meta.json) are supported")

     # Load metadata to get passage sources
     with open(passages_file, "r") as f:
         meta = json.load(f)

     passages = PassageManager(meta["passage_sources"])
-    print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
+    print(
+        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
+    )

     def zmq_server_thread():
         """ZMQ server thread"""
@@ -112,7 +108,7 @@ def create_hnsw_embedding_server(
                         f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
                     )
-                    # Use unified embedding computation
+                    # Use unified embedding computation (now with model caching)
                     embeddings = compute_embeddings(
                         request_payload, model_name, mode=embedding_mode
                     )
@@ -148,15 +144,15 @@ def create_hnsw_embedding_server(
                                 texts.append(txt)
                             except KeyError:
                                 print(f"ERROR: Passage ID {nid} not found")
-                                raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
+                                raise RuntimeError(
+                                    f"FATAL: Passage with ID {nid} not found"
+                                )
                             except Exception as e:
                                 print(f"ERROR: Exception looking up passage ID {nid}: {e}")
                                 raise

                     # Process embeddings
-                    embeddings = compute_embeddings(
-                        texts, model_name, mode=embedding_mode
-                    )
+                    embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                     print(
                         f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                     )
@@ -204,7 +200,9 @@ def create_hnsw_embedding_server(
                             passage_data = passages.get_passage(str(nid))
                             txt = passage_data["text"]
                             if not txt:
-                                raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
+                                raise RuntimeError(
+                                    f"FATAL: Empty text for passage ID {nid}"
+                                )
                             texts.append(txt)
                         except KeyError:
                             raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
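A usage sketch for the server entry point above, assuming it is importable from the HNSW backend package (the import path and the index location are placeholders, not taken from this diff); only .meta.json metadata files are accepted:

# Hypothetical import path; adjust to wherever create_hnsw_embedding_server lives.
from leann_backend_hnsw.hnsw_embedding_server import create_hnsw_embedding_server

create_hnsw_embedding_server(
    passages_file="./my_index.meta.json",  # placeholder index metadata file
    zmq_port=5557,
    model_name="sentence-transformers/all-mpnet-base-v2",
    distance_metric="mips",
)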

View File

@@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code.
 import json
 import pickle
+from leann.interface import LeannBackendSearcherInterface
 import numpy as np
+import time
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
@@ -126,6 +128,7 @@ class PassageManager:
     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
+            # Lazy file opening - only open when needed
             with open(passage_file, "r", encoding="utf-8") as f:
                 f.seek(offset)
                 return json.loads(f.readline())
@@ -373,10 +376,12 @@ class LeannBuilder:
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
-        meta_path_str = f"{index_path}.meta.json"
-        if not Path(meta_path_str).exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
-        with open(meta_path_str, "r", encoding="utf-8") as f:
+        self.meta_path_str = f"{index_path}.meta.json"
+        if not Path(self.meta_path_str).exists():
+            raise FileNotFoundError(
+                f"Leann metadata file not found at {self.meta_path_str}"
+            )
+        with open(self.meta_path_str, "r", encoding="utf-8") as f:
             self.meta_data = json.load(f)
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
@@ -390,7 +395,9 @@ class LeannSearcher:
             raise ValueError(f"Backend '{backend_name}' not found.")

         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
         final_kwargs["enable_warmup"] = enable_warmup
-        self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
+        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
+            index_path, **final_kwargs
+        )

     def search(
         self,
@@ -399,9 +406,9 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        expected_zmq_port: int = 5557,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -409,16 +416,21 @@ class LeannSearcher:
         print(f" Top_k: {top_k}")
         print(f" Additional kwargs: {kwargs}")

-        # Use backend's compute_query_embedding method
-        # This will automatically use embedding server if available and needed
-        import time
         start_time = time.time()
+        zmq_port = None
+        if recompute_embeddings:
+            zmq_port = self.backend_impl._ensure_server_running(
+                self.meta_path_str,
+                port=expected_zmq_port,
+                **kwargs,
+            )
+        del expected_zmq_port
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            expected_zmq_port,
             use_server_if_available=recompute_embeddings,
+            zmq_port=zmq_port,
         )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
@@ -433,7 +445,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=expected_zmq_port,
+            expected_zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -488,10 +500,10 @@ class LeannChat:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
+        expected_zmq_port: int = 5557,
         **search_kwargs,
     ):
         if llm_kwargs is None:
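With these API changes, LeannSearcher.search() now defaults to recompute_embeddings=True, starts the embedding server itself via the backend's _ensure_server_running(), and forwards the returned port to both compute_query_embedding() and the backend search. A hedged usage sketch (the import path and index prefix are placeholders):

from leann.api import LeannSearcher  # assumed module path

searcher = LeannSearcher("./my_index")  # placeholder index prefix
# Default path: the embedding server is started (or reused) and embeddings are recomputed.
results = searcher.search("how is the model cached?", top_k=5)
# Opt out of recompute: no server is started, stored codes are used instead.
results_pq = searcher.search("how is the model cached?", top_k=5, recompute_embeddings=False)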

View File

@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
 import numpy as np
 import torch
-from typing import List
+from typing import List, Dict, Any, Optional
 import logging

 logger = logging.getLogger(__name__)

+# Global model cache to avoid repeated loading
+_model_cache: Dict[str, Any] = {}
+

 def compute_embeddings(
     texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False
@@ -45,25 +48,12 @@ def compute_embeddings_sentence_transformers(
     is_build: bool = False,
 ) -> np.ndarray:
     """
-    Compute embeddings using SentenceTransformer
-    Preserves all optimization parameters to ensure consistency with original embedding_server
-
-    Args:
-        texts: List of texts to compute embeddings for
-        model_name: SentenceTransformer model name
-        use_fp16: Whether to use FP16 precision
-        device: Device selection ('auto', 'cuda', 'mps', 'cpu')
-        batch_size: Batch size for processing
-
-    Returns:
-        Normalized embeddings array, shape: (len(texts), embedding_dim)
+    Compute embeddings using SentenceTransformer with model caching
     """
     print(
         f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
     )
-    from sentence_transformers import SentenceTransformer

     # Auto-detect device
     if device == "auto":
         if torch.cuda.is_available():
@@ -73,62 +63,72 @@ def compute_embeddings_sentence_transformers(
         else:
             device = "cpu"

-    print(f"INFO: Using device: {device}")
-
-    # Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
-    model_kwargs = {
-        "torch_dtype": torch.float16 if use_fp16 else torch.float32,
-        "low_cpu_mem_usage": True,
-        "_fast_init": True,  # Skip weight initialization checks for faster loading
-    }
-
-    tokenizer_kwargs = {
-        "use_fast": True,  # Use fast tokenizer for better runtime performance
-    }
-
-    # Load SentenceTransformer (try local first, then network)
-    print(f"INFO: Loading SentenceTransformer model: {model_name}")
-
-    try:
-        # Try local loading (avoid network delays)
-        model_kwargs["local_files_only"] = True
-        tokenizer_kwargs["local_files_only"] = True
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=True,
-        )
-        print("✅ Model loaded successfully! (local + optimized)")
-    except Exception as e:
-        print(f"Local loading failed ({e}), trying network download...")
-        # Fallback to network loading
-        model_kwargs["local_files_only"] = False
-        tokenizer_kwargs["local_files_only"] = False
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=False,
-        )
-        print("✅ Model loaded successfully! (network + optimized)")
-
-    # Apply additional optimizations (if supported)
-    if use_fp16 and device in ["cuda", "mps"]:
-        try:
-            model = model.half()
-            model = torch.compile(model)
-            print(f"✅ Using FP16 precision and compile optimization: {model_name}")
-        except Exception as e:
-            print(
-                f"FP16 or compile optimization failed, continuing with default settings: {e}"
-            )
-
-    # Compute embeddings (using SentenceTransformer's optimized implementation)
+    # Create cache key
+    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
+
+    # Check if model is already cached
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached model: {model_name}")
+        model = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching SentenceTransformer model: {model_name}")
+        from sentence_transformers import SentenceTransformer
+
+        print(f"INFO: Using device: {device}")
+
+        # Prepare model and tokenizer optimization parameters
+        model_kwargs = {
+            "torch_dtype": torch.float16 if use_fp16 else torch.float32,
+            "low_cpu_mem_usage": True,
+            "_fast_init": True,
+        }
+
+        tokenizer_kwargs = {
+            "use_fast": True,
+        }
+
+        try:
+            # Try local loading first
+            model_kwargs["local_files_only"] = True
+            tokenizer_kwargs["local_files_only"] = True
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                local_files_only=True,
+            )
+            print("✅ Model loaded successfully! (local + optimized)")
+        except Exception as e:
+            print(f"Local loading failed ({e}), trying network download...")
+            # Fallback to network loading
+            model_kwargs["local_files_only"] = False
+            tokenizer_kwargs["local_files_only"] = False
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                local_files_only=False,
+            )
+            print("✅ Model loaded successfully! (network + optimized)")
+
+        # Apply additional optimizations (if supported)
+        if use_fp16 and device in ["cuda", "mps"]:
+            try:
+                model = model.half()
+                model = torch.compile(model)
+                print(f"✅ Using FP16 precision and compile optimization: {model_name}")
+            except Exception as e:
+                print(f"FP16 or compile optimization failed: {e}")
+
+        # Cache the model
+        _model_cache[cache_key] = model
+        print(f"✅ Model cached: {cache_key}")
+
+    # Compute embeddings
     print("INFO: Starting embedding computation...")
     embeddings = model.encode(
@@ -136,7 +136,7 @@ def compute_embeddings_sentence_transformers(
         batch_size=batch_size,
         show_progress_bar=is_build,  # Don't show progress bar in server environment
         convert_to_numpy=True,
-        normalize_embeddings=False,  # Keep consistent with original API behavior
+        normalize_embeddings=False,
         device=device,
     )
@@ -166,7 +166,14 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")

-    client = openai.OpenAI(api_key=api_key)
+    # Cache OpenAI client
+    cache_key = "openai_client"
+    if cache_key in _model_cache:
+        client = _model_cache[cache_key]
+    else:
+        client = openai.OpenAI(api_key=api_key)
+        _model_cache[cache_key] = client
+        print("✅ OpenAI client cached")

     print(
         f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
@@ -214,7 +221,6 @@ def compute_embeddings_mlx(
     try:
         import mlx.core as mx
         from mlx_lm.utils import load
-        from tqdm import tqdm
     except ImportError as e:
         raise RuntimeError(
             "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
@@ -224,8 +230,16 @@
         f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
     )

-    # Load model and tokenizer
-    model, tokenizer = load(model_name)
+    # Cache MLX model and tokenizer
+    cache_key = f"mlx_{model_name}"
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached MLX model: {model_name}")
+        model, tokenizer = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching MLX model: {model_name}")
+        model, tokenizer = load(model_name)
+        _model_cache[cache_key] = (model, tokenizer)
+        print(f"✅ MLX model cached: {cache_key}")

     # Process chunks in batches with progress bar
     all_embeddings = []
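The practical effect of the cache in this module: the first compute_embeddings() call for a given model loads it, and subsequent calls with the same model and mode hit _model_cache instead of reloading. A small sketch (the import path is assumed from the module name in this diff):

from leann.embedding_compute import compute_embeddings  # assumed module path

emb1 = compute_embeddings(["first call loads the model"], "sentence-transformers/all-mpnet-base-v2")
emb2 = compute_embeddings(["second call reuses it"], "sentence-transformers/all-mpnet-base-v2")
print(emb1.shape, emb2.shape)  # e.g. (1, 768) each; no second model load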

View File

@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
         """
         pass

+    @abstractmethod
+    def _ensure_server_running(
+        self, passages_source_file: str, port: Optional[int], **kwargs
+    ) -> int:
+        """Ensure server is running"""
+        pass
+
     @abstractmethod
     def search(
         self,
@@ -44,7 +51,7 @@
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Search for nearest neighbors
@@ -57,7 +64,7 @@
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            expected_zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters

         Returns:
@@ -69,14 +76,14 @@
     def compute_query_embedding(
         self,
         query: str,
-        expected_zmq_port: Optional[int] = None,
         use_server_if_available: bool = True,
+        zmq_port: Optional[int] = None,
     ) -> np.ndarray:
         """Compute embedding for a query string

         Args:
             query: The query string to embed
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:

View File

@@ -1,5 +1,4 @@
 import json
-import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, Any, Literal, Optional
@@ -88,15 +87,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
     def compute_query_embedding(
         self,
         query: str,
-        expected_zmq_port: int = 5557,
         use_server_if_available: bool = True,
+        zmq_port: int = 5557,
     ) -> np.ndarray:
         """
         Compute embedding for a query string.

         Args:
             query: The query string to embed
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server
             use_server_if_available: Whether to try using embedding server first

         Returns:
@@ -110,7 +109,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
                 self.index_dir / f"{self.index_path.name}.meta.json"
             )
             zmq_port = self._ensure_server_running(
-                str(passages_source_file), expected_zmq_port
+                str(passages_source_file), zmq_port
             )
             return self._compute_embedding_via_server([query], zmq_port)[
@@ -168,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        zmq_port: Optional[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """
@@ -182,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            expected_zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

         Returns: