fix: cache the loaded model
@@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code.

import json
import pickle
from leann.interface import LeannBackendSearcherInterface
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field

@@ -126,6 +128,7 @@ class PassageManager:
def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
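The get_passage path above relies on a precomputed offset map so passage files are only opened on demand. A minimal, self-contained sketch of that lookup pattern, assuming one JSON object per line and an offset map built at index time (the example entry and file name are made up):

    import json
    from typing import Any, Dict, Tuple

    # Hypothetical offset map: passage_id -> (path to a JSONL file, byte offset of its line).
    # In the real PassageManager this map comes from the index metadata.
    global_offset_map: Dict[str, Tuple[str, int]] = {
        "doc-42": ("passages_0.jsonl", 1024),
    }

    def get_passage(passage_id: str) -> Dict[str, Any]:
        """Open the passage file lazily, seek to the stored offset, and parse one JSON line."""
        passage_file, offset = global_offset_map[passage_id]
        with open(passage_file, "r", encoding="utf-8") as f:
            f.seek(offset)
            return json.loads(f.readline())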

@@ -373,10 +376,12 @@ class LeannBuilder:

class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, "r", encoding="utf-8") as f:
self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists():
raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}"
)
with open(self.meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]

@@ -390,7 +395,9 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)

def search(
self,

@@ -399,9 +406,9 @@ class LeannSearcher:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
expected_zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:")

@@ -409,16 +416,21 @@ class LeannSearcher:
print(f" Top_k: {top_k}")
print(f" Additional kwargs: {kwargs}")

# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
import time

start_time = time.time()

zmq_port = None
if recompute_embeddings:
zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str,
port=expected_zmq_port,
**kwargs,
)
del expected_zmq_port

query_embedding = self.backend_impl.compute_query_embedding(
query,
expected_zmq_port,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
print(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
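A rough sketch of the control flow introduced above (parameter names follow the diff; the backend object is a stand-in): the embedding server is started only when recompute_embeddings is requested, and the port it reports back is what gets threaded into the embedding call and, later, into the backend search.

    import time
    from typing import Optional

    def embed_query(backend, meta_path_str: str, query: str,
                    recompute_embeddings: bool = True,
                    expected_zmq_port: int = 5557):
        """Sketch of the search() preamble: start the server only if needed, then reuse its port."""
        start_time = time.time()
        zmq_port: Optional[int] = None
        if recompute_embeddings:
            # The backend decides whether a server is already up and returns the live port.
            zmq_port = backend._ensure_server_running(meta_path_str, port=expected_zmq_port)
        query_embedding = backend.compute_query_embedding(
            query,
            use_server_if_available=recompute_embeddings,
            zmq_port=zmq_port,
        )
        print(f"Embedding took {time.time() - start_time:.3f}s")
        return query_embedding, zmq_port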

@@ -433,7 +445,7 @@ class LeannSearcher:
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
expected_zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time

@@ -488,10 +500,10 @@ class LeannChat:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557,
**search_kwargs,
):
if llm_kwargs is None:
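For context, a hypothetical caller after this change does not need to pass a port at all; recomputation and the default port are handled internally. The import path, index path, and query below are illustrative only:

    from leann.api import LeannSearcher  # import path assumed for illustration

    searcher = LeannSearcher("./indexes/my_docs.leann")
    results = searcher.search("how does the model cache work?", top_k=3)
    for result in results:
        print(result)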

@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance

import numpy as np
import torch
from typing import List
from typing import List, Dict, Any, Optional
import logging

logger = logging.getLogger(__name__)

# Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {}


def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers", is_build: bool = False
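The module-level `_model_cache` dict is the heart of this commit: expensive objects are keyed by everything that affects how they were constructed, so a second call with the same configuration reuses the in-process object. A minimal, generic sketch of the pattern (the loader callable is a placeholder, not the real loading logic):

    from typing import Any, Dict

    _model_cache: Dict[str, Any] = {}

    def get_or_load(model_name: str, device: str, use_fp16: bool, loader) -> Any:
        """Return a cached model built with the same settings, else load and cache one."""
        cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
        if cache_key in _model_cache:
            return _model_cache[cache_key]
        model = loader(model_name, device, use_fp16)  # placeholder for the real loading step
        _model_cache[cache_key] = model
        return model

Because the cache lives at module scope, it persists for the lifetime of the process (for example, a long-running embedding server), which is the case this commit targets; it is not shared across processes.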

@@ -45,25 +48,12 @@ def compute_embeddings_sentence_transformers(
is_build: bool = False,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer
Preserves all optimization parameters to ensure consistency with original embedding_server

Args:
texts: List of texts to compute embeddings for
model_name: SentenceTransformer model name
use_fp16: Whether to use FP16 precision
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing

Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
Compute embeddings using SentenceTransformer with model caching
"""
print(
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)

from sentence_transformers import SentenceTransformer

# Auto-detect device
if device == "auto":
if torch.cuda.is_available():

@@ -73,62 +63,72 @@ def compute_embeddings_sentence_transformers(
else:
device = "cpu"

print(f"INFO: Using device: {device}")
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"

# Check if model is already cached
if cache_key in _model_cache:
print(f"INFO: Using cached model: {model_name}")
model = _model_cache[cache_key]
else:
print(f"INFO: Loading and caching SentenceTransformer model: {model_name}")
from sentence_transformers import SentenceTransformer

# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True,  # Skip weight initialization checks for faster loading
}
print(f"INFO: Using device: {device}")

tokenizer_kwargs = {
"use_fast": True,  # Use fast tokenizer for better runtime performance
}
# Prepare model and tokenizer optimization parameters
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True,
}

# Load SentenceTransformer (try local first, then network)
print(f"INFO: Loading SentenceTransformer model: {model_name}")
tokenizer_kwargs = {
"use_fast": True,
}

try:
# Try local loading (avoid network delays)
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True

model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False

model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")

# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(
f"FP16 or compile optimization failed, continuing with default settings: {e}"
)
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True

# Compute embeddings (using SentenceTransformer's optimized implementation)
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
print("✅ Model loaded successfully! (local + optimized)")
except Exception as e:
print(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False

model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
print("✅ Model loaded successfully! (network + optimized)")

# Apply additional optimizations (if supported)
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
model = torch.compile(model)
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
except Exception as e:
print(f"FP16 or compile optimization failed: {e}")

# Cache the model
_model_cache[cache_key] = model
print(f"✅ Model cached: {cache_key}")

# Compute embeddings
print("INFO: Starting embedding computation...")

embeddings = model.encode(

@@ -136,7 +136,7 @@ def compute_embeddings_sentence_transformers(
batch_size=batch_size,
show_progress_bar=is_build,  # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,  # Keep consistent with original API behavior
normalize_embeddings=False,
device=device,
)
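On a cache miss, the loading branch above tries a purely offline load first and only then allows a download. A condensed sketch of that fallback, assuming sentence-transformers is installed; the model_kwargs/tokenizer_kwargs keys mirror the ones used in the diff:

    import torch
    from sentence_transformers import SentenceTransformer

    def load_st_model(model_name: str, device: str, use_fp16: bool) -> SentenceTransformer:
        """Try an offline load first to avoid network latency, then fall back to downloading."""
        model_kwargs = {
            "torch_dtype": torch.float16 if use_fp16 else torch.float32,
            "low_cpu_mem_usage": True,
        }
        tokenizer_kwargs = {"use_fast": True}
        for local_only in (True, False):
            try:
                return SentenceTransformer(
                    model_name,
                    device=device,
                    model_kwargs={**model_kwargs, "local_files_only": local_only},
                    tokenizer_kwargs={**tokenizer_kwargs, "local_files_only": local_only},
                    local_files_only=local_only,
                )
            except Exception as exc:
                if not local_only:
                    raise
                print(f"Local load failed ({exc}); retrying with download allowed...")
        raise RuntimeError("unreachable")

In the diff this load happens only once per cache key; every later call with the same model, device, and precision reuses the cached instance and goes straight to model.encode().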

@@ -166,7 +166,14 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")

client = openai.OpenAI(api_key=api_key)
# Cache OpenAI client
cache_key = "openai_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
_model_cache[cache_key] = client
print("✅ OpenAI client cached")

print(
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"

@@ -214,7 +221,6 @@ def compute_embeddings_mlx(
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"

@@ -224,8 +230,16 @@ def compute_embeddings_mlx(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)

# Load model and tokenizer
model, tokenizer = load(model_name)
# Cache MLX model and tokenizer
cache_key = f"mlx_{model_name}"
if cache_key in _model_cache:
print(f"INFO: Using cached MLX model: {model_name}")
model, tokenizer = _model_cache[cache_key]
else:
print(f"INFO: Loading and caching MLX model: {model_name}")
model, tokenizer = load(model_name)
_model_cache[cache_key] = (model, tokenizer)
print(f"✅ MLX model cached: {cache_key}")

# Process chunks in batches with progress bar
all_embeddings = []
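The batching loop that follows the comment above can be sketched generically; embed_batch below is a placeholder for the backend-specific tokenize/forward/pool step, not the MLX implementation itself:

    import numpy as np
    from tqdm import tqdm

    def embed_in_batches(chunks, batch_size, embed_batch):
        """Run the placeholder embed_batch over fixed-size slices and stack the results."""
        all_embeddings = []
        for start in tqdm(range(0, len(chunks), batch_size), desc="Embedding batches"):
            batch = chunks[start : start + batch_size]
            all_embeddings.append(np.asarray(embed_batch(batch)))
        return np.concatenate(all_embeddings, axis=0)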

@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
"""
pass

@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
"""Ensure server is running"""
pass

@abstractmethod
def search(
self,
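The new abstract method only fixes the signature; how a backend satisfies it is up to the implementation. A purely hypothetical sketch of a concrete backend (the subprocess command, module path, and readiness logic are invented for illustration):

    import socket
    import subprocess
    import sys
    from typing import Optional

    class ExampleSearcher:  # would inherit LeannBackendSearcherInterface in practice
        def __init__(self):
            self._server_proc = None
            self._server_port: Optional[int] = None

        def _ensure_server_running(self, passages_source_file: str,
                                   port: Optional[int] = None, **kwargs) -> int:
            """Start a (hypothetical) embedding server once and return the port it listens on."""
            if self._server_proc is not None and self._server_proc.poll() is None:
                return self._server_port  # already running; reuse the existing port
            if port is None:
                with socket.socket() as s:  # pick a free port if none was requested
                    s.bind(("127.0.0.1", 0))
                    port = s.getsockname()[1]
            # Invented command line; the real backend would launch its own server module.
            self._server_proc = subprocess.Popen(
                [sys.executable, "-m", "my_backend.embedding_server",
                 "--passages", passages_source_file, "--port", str(port)]
            )
            self._server_port = port
            return port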

@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
zmq_port: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors

@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters

Returns:

@@ -69,14 +76,14 @@ class LeannBackendSearcherInterface(ABC):
def compute_query_embedding(
self,
query: str,
expected_zmq_port: Optional[int] = None,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
) -> np.ndarray:
"""Compute embedding for a query string

Args:
query: The query string to embed
expected_zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first

Returns:

@@ -1,5 +1,4 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, Literal, Optional

@@ -88,15 +87,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
def compute_query_embedding(
self,
query: str,
expected_zmq_port: int = 5557,
use_server_if_available: bool = True,
zmq_port: int = 5557,
) -> np.ndarray:
"""
Compute embedding for a query string.

Args:
query: The query string to embed
expected_zmq_port: ZMQ port for embedding server
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first

Returns:

@@ -110,7 +109,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
self.index_dir / f"{self.index_path.name}.meta.json"
)
zmq_port = self._ensure_server_running(
str(passages_source_file), expected_zmq_port
str(passages_source_file), zmq_port
)

return self._compute_embedding_via_server([query], zmq_port)[

@@ -168,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: Optional[int] = None,
zmq_port: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""

@@ -182,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

Returns: