perf: make embedder loading faster by 6x, and embed queries through the server
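Not part of the diff below: a minimal client-side sketch, under assumed settings, of the new direct-text path this commit adds to the ZMQ server (a msgpack-packed list of plain query strings is answered with the embedding matrix). The host and port are placeholders for illustration; use the zmq_port the server was actually started with.

    import msgpack
    import numpy as np
    import zmq

    ZMQ_PORT = 5557  # assumed port, not taken from this commit

    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://localhost:{ZMQ_PORT}")

    # A list of plain strings (none starting with "__") is treated by the
    # server loop in this diff as a direct text embedding request.
    sock.send(msgpack.packb(["what is leann?", "how does hnsw search work?"]))
    embeddings = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)
    print(embeddings.shape)  # (num_queries, hidden_dim)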
@@ -18,10 +18,19 @@ import json
 from pathlib import Path
 from typing import Dict, Any, Optional, Union
 import sys
+import logging

 RED = "\033[91m"
 RESET = "\033[0m"

+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)


 def is_similarity_metric():
     """
@@ -36,6 +45,23 @@ import torch
 from torch import Tensor
 import torch.nn.functional as F

+# Timing utilities
+@contextmanager
+def timer(name: str, sync_cuda: bool = True):
+    """Context manager for timing operations with optional CUDA sync"""
+    start_time = time.time()
+    if sync_cuda and torch.cuda.is_available():
+        torch.cuda.synchronize()
+    try:
+        yield
+    finally:
+        if sync_cuda and torch.cuda.is_available():
+            torch.cuda.synchronize()
+        elif sync_cuda and torch.backends.mps.is_available():
+            torch.mps.synchronize()
+        elapsed = time.time() - start_time
+        logger.info(f"⏱️ {name}: {elapsed:.4f}s")
+

 def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
@@ -120,13 +146,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
                 if passage_data and passage_data.get("text"):
                     return {"text": passage_data["text"]}
                 else:
-                    print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
+                    logger.debug(f"Empty text for ID {int_id} -> {string_id}")
                     return {"text": ""}
             else:
-                print(f"DEBUG: ID {int_id} not found in label_map")
+                logger.debug(f"ID {int_id} not found in label_map")
                 return {"text": ""}
         except Exception as e:
-            print(f"DEBUG: Exception getting passage {passage_id}: {e}")
+            logger.debug(f"Exception getting passage {passage_id}: {e}")
             return {"text": ""}

     def __len__(self) -> int:
@@ -184,8 +210,21 @@ def create_hnsw_embedding_server(
         tokenizer = None  # MLX handles tokenization separately
     else:  # sentence-transformers
-        print(f"Loading tokenizer for {model_name}...")
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-        print(f"Tokenizer loaded successfully!")
+        # Optimized tokenizer loading: try local first, then fallback
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                use_fast=True,  # Use fast tokenizer (better runtime perf)
+                local_files_only=True  # Avoid network delays
+            )
+            print(f"Tokenizer loaded successfully! (local + fast)")
+        except Exception as e:
+            print(f"Local tokenizer failed ({e}), trying network download...")
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                use_fast=True  # Use fast tokenizer
+            )
+            print(f"Tokenizer loaded successfully! (network)")

     # Device setup
     mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -217,9 +256,47 @@ def create_hnsw_embedding_server(
         print("OpenAI API mode - no local model loading required")
         model = None
     else:
-        # Use standard transformers for sentence-transformers models
-        model = AutoModel.from_pretrained(model_name).to(device).eval()
-        print(f"Model {model_name} loaded successfully!")
+        # Use optimized transformers loading for sentence-transformers models
+        print(f"Loading model with optimizations...")
+        try:
+            # Ultra-fast loading: preload config + fast_init
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(model_name, local_files_only=True)
+            model = AutoModel.from_pretrained(
+                model_name,
+                config=config,
+                torch_dtype=torch.float16,  # Half precision for speed
+                low_cpu_mem_usage=True,  # Reduce memory peaks
+                local_files_only=True,  # Avoid network delays
+                _fast_init=True  # Skip weight init checks
+            ).to(device).eval()
+            print(f"Model {model_name} loaded successfully! (ultra-fast)")
+        except Exception as e:
+            print(f"Ultra-fast loading failed ({e}), trying optimized...")
+            try:
+                # Fallback: regular optimized loading
+                model = AutoModel.from_pretrained(
+                    model_name,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                    local_files_only=True
+                ).to(device).eval()
+                print(f"Model {model_name} loaded successfully! (optimized)")
+            except Exception as e2:
+                print(f"Optimized loading failed ({e2}), trying network...")
+                try:
+                    # Fallback: optimized network loading
+                    model = AutoModel.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.float16,
+                        low_cpu_mem_usage=True
+                    ).to(device).eval()
+                    print(f"Model {model_name} loaded successfully! (network + optimized)")
+                except Exception as e3:
+                    print(f"All optimized methods failed ({e3}), using standard...")
+                    # Final fallback: standard loading
+                    model = AutoModel.from_pretrained(model_name).to(device).eval()
+                    print(f"Model {model_name} loaded successfully! (standard)")

     # Check port availability
     import socket
@@ -370,8 +447,9 @@ def create_hnsw_embedding_server(
         if embedding_mode == "mlx":
             return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
         elif embedding_mode == "openai":
-            from leann.api import compute_embeddings_openai
-            return compute_embeddings_openai(texts_batch, model_name)
+            with timer("OpenAI API call", sync_cuda=False):
+                from leann.api import compute_embeddings_openai
+                return compute_embeddings_openai(texts_batch, model_name)

         _is_e5_model = "e5" in model_name.lower()
         _is_bge_model = "bge" in model_name.lower()
@@ -417,44 +495,46 @@ def create_hnsw_embedding_server(
         enc = {k: v.to(device) for k, v in encoded_batch.items()}

         with torch.no_grad():
-            with embed_timer.timing():
-                out = model(enc["input_ids"], enc["attention_mask"])
+            with timer("Model forward pass"):
+                with embed_timer.timing():
+                    out = model(enc["input_ids"], enc["attention_mask"])

-            with pool_timer.timing():
-                if _is_bge_model:
-                    pooled_embeddings = out.last_hidden_state[:, 0]
-                elif not hasattr(out, "last_hidden_state"):
-                    if isinstance(out, torch.Tensor) and len(out.shape) == 2:
-                        pooled_embeddings = out
-                    else:
-                        print(
-                            f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
-                        )
-                        hidden_dim = getattr(
-                            model.config, "hidden_size", 384 if _is_e5_model else 768
-                        )
-                        pooled_embeddings = torch.zeros(
-                            (batch_size, hidden_dim),
-                            device=device,
-                            dtype=enc["input_ids"].dtype
-                            if hasattr(enc["input_ids"], "dtype")
-                            else torch.float32,
-                        )
-                elif _is_e5_model:
-                    pooled_embeddings = e5_average_pool(
-                        out.last_hidden_state, enc["attention_mask"]
-                    )
-                else:
-                    hidden_states = out.last_hidden_state
-                    mask_expanded = (
-                        enc["attention_mask"]
-                        .unsqueeze(-1)
-                        .expand(hidden_states.size())
-                        .float()
-                    )
-                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
-                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
-                    pooled_embeddings = sum_embeddings / sum_mask
+            with timer("Pooling"):
+                with pool_timer.timing():
+                    if _is_bge_model:
+                        pooled_embeddings = out.last_hidden_state[:, 0]
+                    elif not hasattr(out, "last_hidden_state"):
+                        if isinstance(out, torch.Tensor) and len(out.shape) == 2:
+                            pooled_embeddings = out
+                        else:
+                            print(
+                                f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
+                            )
+                            hidden_dim = getattr(
+                                model.config, "hidden_size", 384 if _is_e5_model else 768
+                            )
+                            pooled_embeddings = torch.zeros(
+                                (batch_size, hidden_dim),
+                                device=device,
+                                dtype=enc["input_ids"].dtype
+                                if hasattr(enc["input_ids"], "dtype")
+                                else torch.float32,
+                            )
+                    elif _is_e5_model:
+                        pooled_embeddings = e5_average_pool(
+                            out.last_hidden_state, enc["attention_mask"]
+                        )
+                    else:
+                        hidden_states = out.last_hidden_state
+                        mask_expanded = (
+                            enc["attention_mask"]
+                            .unsqueeze(-1)
+                            .expand(hidden_states.size())
+                            .float()
+                        )
+                        sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
+                        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
+                        pooled_embeddings = sum_embeddings / sum_mask

         final_embeddings = pooled_embeddings
         if _is_e5_model or _is_bge_model:
@@ -536,7 +616,7 @@ def create_hnsw_embedding_server(

     def zmq_server_thread():
         """ZMQ server thread"""
-        nonlocal passages, model, tokenizer, model_name
+        nonlocal passages, model, tokenizer, model_name, embedding_mode
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{zmq_port}")
@@ -556,13 +636,13 @@ def create_hnsw_embedding_server(
             try:
                 request_payload = msgpack.unpackb(message_bytes)
                 if isinstance(request_payload, list):
-                    print(f"DEBUG: request_payload length: {len(request_payload)}")
+                    logger.debug(f"request_payload length: {len(request_payload)}")
                     for i, item in enumerate(request_payload):
                         print(
                             f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
                         )

-                # Handle control messages for meta path and model management
+                # Handle control messages for meta path and model management FIRST
                 if isinstance(request_payload, list) and len(request_payload) >= 1:
                     if request_payload[0] == "__QUERY_META_PATH__":
                         # Return the current meta path being used by the server
@@ -617,19 +697,61 @@ def create_hnsw_embedding_server(
                         )

                         # Clean up old model to free memory
-                        print("INFO: Releasing old model from memory...")
+                        logger.info("Releasing old model from memory...")
                         old_model = model
                         old_tokenizer = tokenizer

-                        # Load new tokenizer first
+                        # Load new tokenizer first (optimized)
                         print(f"Loading new tokenizer for {new_model_name}...")
-                        tokenizer = AutoTokenizer.from_pretrained(
-                            new_model_name, use_fast=True
-                        )
+                        try:
+                            tokenizer = AutoTokenizer.from_pretrained(
+                                new_model_name,
+                                use_fast=True,
+                                local_files_only=True
+                            )
+                            print(f"New tokenizer loaded! (local + fast)")
+                        except:
+                            tokenizer = AutoTokenizer.from_pretrained(
+                                new_model_name,
+                                use_fast=True
+                            )
+                            print(f"New tokenizer loaded! (network + fast)")

-                        # Load new model
+                        # Load new model (optimized)
                         print(f"Loading new model {new_model_name}...")
-                        model = AutoModel.from_pretrained(new_model_name)
+                        try:
+                            # Ultra-fast model switching
+                            from transformers import AutoConfig
+                            config = AutoConfig.from_pretrained(new_model_name, local_files_only=True)
+                            model = AutoModel.from_pretrained(
+                                new_model_name,
+                                config=config,
+                                torch_dtype=torch.float16,
+                                low_cpu_mem_usage=True,
+                                local_files_only=True,
+                                _fast_init=True
+                            )
+                            print(f"New model loaded! (ultra-fast)")
+                        except:
+                            try:
+                                model = AutoModel.from_pretrained(
+                                    new_model_name,
+                                    torch_dtype=torch.float16,
+                                    low_cpu_mem_usage=True,
+                                    local_files_only=True
+                                )
+                                print(f"New model loaded! (optimized)")
+                            except:
+                                try:
+                                    model = AutoModel.from_pretrained(
+                                        new_model_name,
+                                        torch_dtype=torch.float16,
+                                        low_cpu_mem_usage=True
+                                    )
+                                    print(f"New model loaded! (network + optimized)")
+                                except:
+                                    model = AutoModel.from_pretrained(new_model_name)
+                                    print(f"New model loaded! (standard)")
                         model.to(device)
                         model.eval()
@@ -640,19 +762,27 @@ def create_hnsw_embedding_server(
                         # Clear GPU cache if available
                         if device.type == "cuda":
                             torch.cuda.empty_cache()
-                            print("INFO: Cleared CUDA cache")
+                            logger.info("Cleared CUDA cache")
                         elif device.type == "mps":
                             torch.mps.empty_cache()
-                            print("INFO: Cleared MPS cache")
+                            logger.info("Cleared MPS cache")

                         # Update model name
                         model_name = new_model_name

+                        # Re-detect embedding mode based on new model name
+                        if model_name.startswith("text-embedding-"):
+                            embedding_mode = "openai"
+                            logger.info(f"Auto-detected embedding mode: openai for {model_name}")
+                        else:
+                            embedding_mode = "sentence-transformers"
+                            logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}")
+
                         # Force garbage collection
                         import gc

                         gc.collect()
-                        print("INFO: Memory cleanup completed")
+                        logger.info("Memory cleanup completed")

                         response = ["SUCCESS"]
                         print(
@@ -664,6 +794,32 @@ def create_hnsw_embedding_server(
                         socket.send(msgpack.packb(response))
                         continue

+                # Handle direct text embedding request (for OpenAI and sentence-transformers)
+                if isinstance(request_payload, list) and len(request_payload) > 0:
+                    # Check if this is a direct text request (list of strings) and NOT a control message
+                    if (all(isinstance(item, str) for item in request_payload) and
+                            not request_payload[0].startswith("__")):
+                        logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode")
+
+                        try:
+                            if embedding_mode == "openai":
+                                from leann.api import compute_embeddings_openai
+                                embeddings = compute_embeddings_openai(request_payload, model_name)
+                            else:
+                                # sentence-transformers mode - compute directly
+                                with timer(f"Direct text embedding ({len(request_payload)} texts)"):
+                                    embeddings = process_batch(request_payload, [], [])
+
+                            response = embeddings.tolist()
+                            socket.send(msgpack.packb(response))
+                            e2e_end = time.time()
+                            logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
+                            continue
+                        except Exception as e:
+                            logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}")
+                            socket.send(msgpack.packb([]))
+                            continue
+
                 # Handle distance calculation requests
                 if (
                     isinstance(request_payload, list)
@@ -674,7 +830,7 @@ def create_hnsw_embedding_server(
                     node_ids = request_payload[0]
                     query_vector = np.array(request_payload[1], dtype=np.float32)

-                    print("DEBUG: Distance calculation request received")
+                    logger.debug("Distance calculation request received")
                     print(f" Node IDs: {node_ids}")
                     print(f" Query vector dim: {len(query_vector)}")
                     print(f" Passages loaded: {len(passages)}")
@@ -684,7 +840,7 @@ def create_hnsw_embedding_server(
                     missing_ids = []
                     with lookup_timer.timing():
                         for nid in node_ids:
-                            print(f"DEBUG: Looking up passage ID {nid}")
+                            logger.debug(f"Looking up passage ID {nid}")
                            try:
                                txtinfo = passages[nid]
                                if txtinfo is None:
@@ -804,29 +960,11 @@ def create_hnsw_embedding_server(
                     elif device.type == "mps":
                         torch.mps.synchronize()
                     e2e_end = time.time()
-                    print(
-                        f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
+                    logger.info(
+                        f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
                     )
                     continue

-                # Handle direct text embedding request (for OpenAI mode)
-                if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
-                    # Check if this is a direct text request (list of strings)
-                    if all(isinstance(item, str) for item in request_payload):
-                        print(f"Processing direct text embedding request for {len(request_payload)} texts")
-
-                        try:
-                            from leann.api import compute_embeddings_openai
-                            embeddings = compute_embeddings_openai(request_payload, model_name)
-                            response = embeddings.tolist()
-                            socket.send(msgpack.packb(response))
-                            e2e_end = time.time()
-                            print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
-                            continue
-                        except Exception as e:
-                            print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
-                            socket.send(msgpack.packb([]))
-                            continue
-
                 # Standard embedding request (passage ID lookup)
                 if (
@@ -945,10 +1083,10 @@ def create_hnsw_embedding_server(
                 elif device.type == "mps":
                     torch.mps.synchronize()
                 e2e_end = time.time()
-                print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
+                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

            except zmq.Again:
-                print("ZMQ socket timeout, continuing to listen")
+                logger.debug("ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                print(f"Error in ZMQ server loop: {e}")