fix: cache the loaded model

Andy Lee
2025-07-21 21:20:53 -07:00
parent 727724990e
commit b3970793cf
9 changed files with 163 additions and 146 deletions

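The core of this commit, per the title and the embedding_compute changes further down, is a module-level cache so the embedding model is loaded once per process and reused across requests instead of being reloaded on every call. A minimal sketch of the pattern, assuming sentence-transformers is installed (get_cached_model is a hypothetical helper name; the real code inlines this logic inside compute_embeddings_sentence_transformers, with a richer cache key and loading options):

    from typing import Any, Dict

    _model_cache: Dict[str, Any] = {}

    def get_cached_model(model_name: str, device: str = "cpu", use_fp16: bool = False):
        # Key on everything that changes how the model is constructed.
        cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
        if cache_key not in _model_cache:
            from sentence_transformers import SentenceTransformer  # lazy import, as in the diff
            _model_cache[cache_key] = SentenceTransformer(model_name, device=device)
        return _model_cache[cache_key]
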
View File

@@ -70,9 +70,7 @@ async def main(args):
# )
print(f"You: {query}")
chat_response = chat.ask(
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
)
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")

View File

@@ -4,7 +4,6 @@ import struct
from pathlib import Path
from typing import Dict, Any, List, Literal
import contextlib
import pickle
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
@@ -70,7 +69,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs}
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
@@ -207,8 +205,7 @@ class DiskannSearcher(BaseSearcher):
)
string_labels = [
[str(int_label) for int_label in batch_labels]
for batch_labels in labels
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -1,10 +1,9 @@
import numpy as np
import os
from pathlib import Path
from typing import Dict, Any, List, Literal
import pickle
from typing import Dict, Any, List, Literal, Optional
import shutil
import time
import logging
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -16,6 +15,8 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
logger = logging.getLogger(__name__)
def get_metric_map():
from . import faiss # type: ignore
@@ -57,9 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -81,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
@@ -90,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
)
if success:
print("✅ CSR conversion successful.")
logger.info("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
print(
logger.info(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else:
@@ -131,13 +132,11 @@ class HNSWSearcher(BaseSearcher):
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned but recompute is disabled.")
hnsw_config.is_recompute = (
self.is_pruned
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
def search(
self,
@@ -146,9 +145,9 @@ class HNSWSearcher(BaseSearcher):
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
expected_zmq_port: Optional[int] = None,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
@@ -166,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
zmq_port: ZMQ port for embedding server
expected_zmq_port: ZMQ port for embedding server
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -175,15 +174,9 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss # type: ignore
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings or self.is_pruned
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -191,7 +184,7 @@ class HNSWSearcher(BaseSearcher):
faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW()
params.zmq_port = zmq_port
params.zmq_port = expected_zmq_port
params.efSearch = complexity
params.beam_size = beam_width
@@ -228,8 +221,7 @@ class HNSWSearcher(BaseSearcher):
)
string_labels = [
[str(int_label) for int_label in batch_labels]
for batch_labels in labels
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

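To summarize the HNSW searcher changes above: recompute_embeddings now defaults to True, a pruned index hard-fails if recompute is disabled, and the searcher no longer starts the embedding server itself; the caller passes the already-running server's port as expected_zmq_port. A hedged usage sketch (the searcher instance, query dimension, and port value are placeholders):

    import numpy as np

    query = np.random.rand(1, 768).astype(np.float32)  # placeholder query embedding
    results = searcher.search(                          # searcher: an HNSWSearcher instance
        query,
        top_k=10,
        complexity=64,
        recompute_embeddings=True,   # now the default; mandatory when the index is pruned
        expected_zmq_port=5557,      # renamed from zmq_port; must point at a running server
    )
    print(results["labels"], results["distances"])
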
View File

@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
passages_data: Optional[Dict[str, str]] = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
distance_metric: str = "mips",
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using unified embedding computation module.
"""
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith(
"text-embedding-"
):
embedding_mode = "openai"
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Using embedding mode: {embedding_mode}")
@@ -64,6 +57,7 @@ def create_hnsw_embedding_server(
finally:
sys.path.pop(0)
# Check port availability
import socket
@@ -78,13 +72,15 @@ def create_hnsw_embedding_server(
# Only support metadata file, fail fast for everything else
if not passages_file or not passages_file.endswith(".meta.json"):
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file, "r") as f:
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
print(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
def zmq_server_thread():
"""ZMQ server thread"""
@@ -112,7 +108,7 @@ def create_hnsw_embedding_server(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
# Use unified embedding computation
# Use unified embedding computation (now with model caching)
embeddings = compute_embeddings(
request_payload, model_name, mode=embedding_mode
)
@@ -148,15 +144,15 @@ def create_hnsw_embedding_server(
texts.append(txt)
except KeyError:
print(f"ERROR: Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
@@ -204,7 +200,9 @@ def create_hnsw_embedding_server(
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")

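The server above now only accepts an index's .meta.json file as its passage source and delegates all embedding work to compute_embeddings, which after this commit caches the loaded model, so only the first ZMQ request pays the load cost. A hedged sketch of starting it (the path and port are placeholders; other keyword arguments keep the defaults shown in the signature above):

    create_hnsw_embedding_server(
        passages_file="my_index.meta.json",  # anything not ending in .meta.json raises ValueError
        zmq_port=5557,
        model_name="sentence-transformers/all-mpnet-base-v2",
    )
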
View File

@@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code.
import json
import pickle
from leann.interface import LeannBackendSearcherInterface
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
@@ -126,6 +128,7 @@ class PassageManager:
def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
@@ -373,10 +376,12 @@ class LeannBuilder:
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, "r", encoding="utf-8") as f:
self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists():
raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}"
)
with open(self.meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
@@ -390,7 +395,9 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
def search(
self,
@@ -399,9 +406,9 @@ class LeannSearcher:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
expected_zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:")
@@ -409,16 +416,21 @@ class LeannSearcher:
print(f" Top_k: {top_k}")
print(f" Additional kwargs: {kwargs}")
# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
import time
start_time = time.time()
zmq_port = None
if recompute_embeddings:
zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str,
port=expected_zmq_port,
**kwargs,
)
del expected_zmq_port
query_embedding = self.backend_impl.compute_query_embedding(
query,
expected_zmq_port,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
print(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
@@ -433,7 +445,7 @@ class LeannSearcher:
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
expected_zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time
@@ -488,10 +500,10 @@ class LeannChat:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557,
**search_kwargs,
):
if llm_kwargs is None:

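With the api.py changes above, LeannSearcher.search() owns the server lifecycle: when recompute_embeddings is true (now the default) it calls the backend's _ensure_server_running() once, then reuses the returned port both for computing the query embedding and for the backend search. A hedged end-to-end sketch (the index path and query text are placeholders):

    searcher = LeannSearcher("indexes/my_index")  # expects indexes/my_index.meta.json to exist
    results = searcher.search("how is the model cached?", top_k=5)
    # The first call starts the embedding server and loads the model; subsequent
    # calls reuse the running server and the cached model.
    for r in results:
        print(r)
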
View File

@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
import numpy as np
import torch
from typing import List
from typing import List, Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)
# Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {}
def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False
@@ -45,25 +48,12 @@ def compute_embeddings_sentence_transformers(
is_build: bool = False,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer
Preserves all optimization parameters to ensure consistency with original embedding_server
Args:
texts: List of texts to compute embeddings for
model_name: SentenceTransformer model name
use_fp16: Whether to use FP16 precision
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
Compute embeddings using SentenceTransformer with model caching
"""
print(
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
from sentence_transformers import SentenceTransformer
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
@@ -73,62 +63,72 @@ def compute_embeddings_sentence_transformers(
else:
device = "cpu"
print(f"INFO: Using device: {device}")
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
# Check if model is already cached
if cache_key in _model_cache:
print(f"INFO: Using cached model: {model_name}")
model = _model_cache[cache_key]
else:
print(f"INFO: Loading and caching SentenceTransformer model: {model_name}")
from sentence_transformers import SentenceTransformer
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True, # Skip weight initialization checks for faster loading
}
print(f"INFO: Using device: {device}")
tokenizer_kwargs = {
"use_fast": True, # Use fast tokenizer for better runtime performance
}
# Prepare model and tokenizer optimization parameters
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True,
}
# Load SentenceTransformer (try local first, then network)
print(f"INFO: Loading SentenceTransformer model: {model_name}")
tokenizer_kwargs = {
"use_fast": True,
}
try:
# Try local loading (avoid network delays)
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")
# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(
f"FP16 or compile optimization failed, continuing with default settings: {e}"
)
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
# Compute embeddings (using SentenceTransformer's optimized implementation)
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")
# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(f"FP16 or compile optimization failed: {e}")
# Cache the model
_model_cache[cache_key] = model
print(f"✅ Model cached: {cache_key}")
# Compute embeddings
print("INFO: Starting embedding computation...")
embeddings = model.encode(
@@ -136,7 +136,7 @@ def compute_embeddings_sentence_transformers(
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False, # Keep consistent with original API behavior
normalize_embeddings=False,
device=device,
)
@@ -166,7 +166,14 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
# Cache OpenAI client
cache_key = "openai_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
_model_cache[cache_key] = client
print("✅ OpenAI client cached")
print(
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
@@ -214,7 +221,6 @@ def compute_embeddings_mlx(
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
@@ -224,8 +230,16 @@ def compute_embeddings_mlx(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Cache MLX model and tokenizer
cache_key = f"mlx_{model_name}"
if cache_key in _model_cache:
print(f"INFO: Using cached MLX model: {model_name}")
model, tokenizer = _model_cache[cache_key]
else:
print(f"INFO: Loading and caching MLX model: {model_name}")
model, tokenizer = load(model_name)
_model_cache[cache_key] = (model, tokenizer)
print(f"✅ MLX model cached: {cache_key}")
# Process chunks in batches with progress bar
all_embeddings = []

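The practical effect of the _model_cache above: the first compute_embeddings() call in a process loads (and optionally compiles) the model, every later call with the same model, device, and precision is a cache hit, and the same dict also caches the OpenAI client and MLX model/tokenizer pairs. The cache has no eviction, which is fine for a long-running server holding one model but worth keeping in mind if many models are loaded. A hedged sketch of the effect (timings are illustrative, not measured):

    import time

    t0 = time.time()
    compute_embeddings(["hello"], "sentence-transformers/all-mpnet-base-v2")  # loads and caches
    t1 = time.time()
    compute_embeddings(["world"], "sentence-transformers/all-mpnet-base-v2")  # cache hit
    t2 = time.time()
    print(f"first call: {t1 - t0:.1f}s (load), second call: {t2 - t1:.2f}s (cached)")
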
View File

@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
"""
pass
@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
"""Ensure server is running"""
pass
@abstractmethod
def search(
self,
@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
zmq_port: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors
@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters
Returns:
@@ -69,14 +76,14 @@ class LeannBackendSearcherInterface(ABC):
def compute_query_embedding(
self,
query: str,
expected_zmq_port: Optional[int] = None,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
) -> np.ndarray:
"""Compute embedding for a query string
Args:
query: The query string to embed
expected_zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:

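The interface change above makes _ensure_server_running() part of the backend contract and has it return the port the server actually bound, which callers then forward as the zmq_port/expected_zmq_port search parameter. A hedged sketch of what an implementation is expected to do (the _server_port attribute and _start_embedding_server helper are hypothetical, not the actual BaseSearcher code):

    def _ensure_server_running(self, passages_source_file: str, port: Optional[int], **kwargs) -> int:
        if getattr(self, "_server_port", None) is not None:
            return self._server_port                  # server already up: reuse its port
        chosen_port = port if port is not None else 5557
        self._start_embedding_server(passages_source_file, chosen_port, **kwargs)  # hypothetical helper
        self._server_port = chosen_port
        return chosen_port                            # callers pass this on to search()
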
View File

@@ -1,5 +1,4 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, Literal, Optional
@@ -88,15 +87,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
def compute_query_embedding(
self,
query: str,
expected_zmq_port: int = 5557,
use_server_if_available: bool = True,
zmq_port: int = 5557,
) -> np.ndarray:
"""
Compute embedding for a query string.
Args:
query: The query string to embed
expected_zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
@@ -110,7 +109,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
self.index_dir / f"{self.index_path.name}.meta.json"
)
zmq_port = self._ensure_server_running(
str(passages_source_file), expected_zmq_port
str(passages_source_file), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[
@@ -168,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
zmq_port: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""
@@ -182,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns:
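
Finally, BaseSearcher.compute_query_embedding() renames expected_zmq_port to zmq_port and keeps the same behavior: if the embedding server is available it embeds the query over ZMQ, otherwise it falls back to local computation. A hedged usage sketch (the searcher instance is assumed to exist):

    embedding = searcher.compute_query_embedding(
        "example query",
        use_server_if_available=True,
        zmq_port=5557,  # was expected_zmq_port before this commit
    )
    print(embedding.shape)  # NumPy array for the single query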