fix: cache the loaded model

Andy Lee
2025-07-21 21:20:53 -07:00
parent 727724990e
commit b3970793cf
9 changed files with 163 additions and 146 deletions


@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
 import numpy as np
 import torch
-from typing import List
+from typing import List, Dict, Any, Optional
 import logging
 logger = logging.getLogger(__name__)
+# Global model cache to avoid repeated loading
+_model_cache: Dict[str, Any] = {}
 def compute_embeddings(
     texts: List[str], model_name: str, mode: str = "sentence-transformers", is_build: bool = False
@@ -45,25 +48,12 @@ def compute_embeddings_sentence_transformers(
     is_build: bool = False,
 ) -> np.ndarray:
     """
-    Compute embeddings using SentenceTransformer
-    Preserves all optimization parameters to ensure consistency with original embedding_server
-    Args:
-        texts: List of texts to compute embeddings for
-        model_name: SentenceTransformer model name
-        use_fp16: Whether to use FP16 precision
-        device: Device selection ('auto', 'cuda', 'mps', 'cpu')
-        batch_size: Batch size for processing
-    Returns:
-        Normalized embeddings array, shape: (len(texts), embedding_dim)
+    Compute embeddings using SentenceTransformer with model caching
     """
-    print(
-        f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
-    )
-    from sentence_transformers import SentenceTransformer
     # Auto-detect device
     if device == "auto":
         if torch.cuda.is_available():
@@ -73,62 +63,72 @@ def compute_embeddings_sentence_transformers(
         else:
             device = "cpu"
     print(f"INFO: Using device: {device}")
-    # Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
-    model_kwargs = {
-        "torch_dtype": torch.float16 if use_fp16 else torch.float32,
-        "low_cpu_mem_usage": True,
-        "_fast_init": True,  # Skip weight initialization checks for faster loading
-    }
-    tokenizer_kwargs = {
-        "use_fast": True,  # Use fast tokenizer for better runtime performance
-    }
-    # Load SentenceTransformer (try local first, then network)
-    print(f"INFO: Loading SentenceTransformer model: {model_name}")
-    try:
-        # Try local loading (avoid network delays)
-        model_kwargs["local_files_only"] = True
-        tokenizer_kwargs["local_files_only"] = True
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=True,
-        )
-        print("✅ Model loaded successfully! (local + optimized)")
-    except Exception as e:
-        print(f"Local loading failed ({e}), trying network download...")
-        # Fallback to network loading
-        model_kwargs["local_files_only"] = False
-        tokenizer_kwargs["local_files_only"] = False
-        model = SentenceTransformer(
-            model_name,
-            device=device,
-            model_kwargs=model_kwargs,
-            tokenizer_kwargs=tokenizer_kwargs,
-            local_files_only=False,
-        )
-        print("✅ Model loaded successfully! (network + optimized)")
-    # Apply additional optimizations (if supported)
-    if use_fp16 and device in ["cuda", "mps"]:
-        try:
-            model = model.half()
-            model = torch.compile(model)
-            print(f"✅ Using FP16 precision and compile optimization: {model_name}")
-        except Exception as e:
-            print(
-                f"FP16 or compile optimization failed, continuing with default settings: {e}"
-            )
-    # Compute embeddings (using SentenceTransformer's optimized implementation)
+    # Create cache key
+    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
+    # Check if model is already cached
+    if cache_key in _model_cache:
+        print(f"INFO: Using cached model: {model_name}")
+        model = _model_cache[cache_key]
+    else:
+        print(f"INFO: Loading and caching SentenceTransformer model: {model_name}")
+        from sentence_transformers import SentenceTransformer
+        # Prepare model and tokenizer optimization parameters
+        model_kwargs = {
+            "torch_dtype": torch.float16 if use_fp16 else torch.float32,
+            "low_cpu_mem_usage": True,
+            "_fast_init": True,
+        }
+        tokenizer_kwargs = {
+            "use_fast": True,
+        }
+        try:
+            # Try local loading first
+            model_kwargs["local_files_only"] = True
+            tokenizer_kwargs["local_files_only"] = True
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                local_files_only=True,
+            )
+            print("✅ Model loaded successfully! (local + optimized)")
+        except Exception as e:
+            print(f"Local loading failed ({e}), trying network download...")
+            # Fallback to network loading
+            model_kwargs["local_files_only"] = False
+            tokenizer_kwargs["local_files_only"] = False
+            model = SentenceTransformer(
+                model_name,
+                device=device,
+                model_kwargs=model_kwargs,
+                tokenizer_kwargs=tokenizer_kwargs,
+                local_files_only=False,
+            )
+            print("✅ Model loaded successfully! (network + optimized)")
+        # Apply additional optimizations (if supported)
+        if use_fp16 and device in ["cuda", "mps"]:
+            try:
+                model = model.half()
+                model = torch.compile(model)
+                print(f"✅ Using FP16 precision and compile optimization: {model_name}")
+            except Exception as e:
+                print(f"FP16 or compile optimization failed: {e}")
+        # Cache the model
+        _model_cache[cache_key] = model
+        print(f"✅ Model cached: {cache_key}")
+    # Compute embeddings
+    print("INFO: Starting embedding computation...")
     embeddings = model.encode(
@@ -136,7 +136,7 @@ def compute_embeddings_sentence_transformers(
         batch_size=batch_size,
         show_progress_bar=is_build,  # Don't show progress bar in server environment
         convert_to_numpy=True,
-        normalize_embeddings=False,  # Keep consistent with original API behavior
+        normalize_embeddings=False,
         device=device,
     )
@@ -166,7 +166,14 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
-    client = openai.OpenAI(api_key=api_key)
+    # Cache OpenAI client
+    cache_key = "openai_client"
+    if cache_key in _model_cache:
+        client = _model_cache[cache_key]
+    else:
+        client = openai.OpenAI(api_key=api_key)
+        _model_cache[cache_key] = client
+        print("✅ OpenAI client cached")
     print(
         f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
@@ -214,7 +221,6 @@ def compute_embeddings_mlx(
     try:
         import mlx.core as mx
         from mlx_lm.utils import load
-        from tqdm import tqdm
     except ImportError as e:
         raise RuntimeError(
             "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
@@ -224,8 +230,16 @@ def compute_embeddings_mlx(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Cache MLX model and tokenizer
cache_key = f"mlx_{model_name}"
if cache_key in _model_cache:
print(f"INFO: Using cached MLX model: {model_name}")
model, tokenizer = _model_cache[cache_key]
else:
print(f"INFO: Loading and caching MLX model: {model_name}")
model, tokenizer = load(model_name)
_model_cache[cache_key] = (model, tokenizer)
print(f"✅ MLX model cached: {cache_key}")
# Process chunks in batches with progress bar
all_embeddings = []
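
The pattern behind this commit, shown below as a minimal standalone sketch: a module-level dict keyed by model name, device, and precision, so repeated calls within the same process reuse the already-loaded model instead of reloading it from disk. The helper name get_cached_sentence_transformer and the example model name are illustrative only and are not part of the repository.

from typing import Any, Dict

# Illustrative module-level cache, mirroring the _model_cache pattern in the diff.
_model_cache: Dict[str, Any] = {}

def get_cached_sentence_transformer(model_name: str, device: str = "cpu", use_fp16: bool = False):
    """Load a SentenceTransformer once per (model, device, precision) and reuse it."""
    cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    # Lazy import keeps module import cheap when embeddings are never requested.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_name, device=device)
    if use_fp16 and device in ("cuda", "mps"):
        model = model.half()
    _model_cache[cache_key] = model
    return model

# Repeated calls in one process hit the cache instead of reloading the weights,
# e.g. (model name is illustrative):
# model = get_cached_sentence_transformer("all-MiniLM-L6-v2", device="cpu")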