perf: make embedder loading faster by 6x, and embed queries through the server
@@ -16,8 +16,17 @@ import zmq
import numpy as np
import msgpack
from pathlib import Path
import logging

RED = "\033[91m"

# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
RESET = "\033[0m"

# --- New Passage Loader from HNSW backend ---
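Both modified files now read their log level from the LEANN_LOG_LEVEL environment variable shown above. A small sketch of setting it from Python before the server module is imported (setting it in the shell has the same effect):

    import os
    os.environ["LEANN_LOG_LEVEL"] = "DEBUG"  # read once at import time by the basicConfig call above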
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
Create and run the embedding server in the current thread.
This function is designed to be called from a separate thread.
"""
print(f"INFO: Initializing embedding server thread on port {zmq_port}")
logger.info(f"Initializing embedding server thread on port {zmq_port}")

try:
# Check whether the port is already in use
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
if embedding_mode == "mlx":
from leann.api import compute_embeddings_mlx
import torch
print("INFO: Using MLX for embeddings")
logger.info("Using MLX for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
import torch
print("INFO: Using OpenAI API for embeddings")
logger.info("Using OpenAI API for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
@@ -213,16 +222,16 @@ def create_embedding_server_thread(

if cuda_available:
device = torch.device("cuda")
print("INFO: Using CUDA device")
logger.info("Using CUDA device")
elif mps_available:
device = torch.device("mps")
print("INFO: Using MPS device (Apple Silicon)")
logger.info("Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
print("INFO: Using CPU device")
logger.info("Using CPU device")

# Load the model
print(f"INFO: Loading model {model_name}")
logger.info(f"Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()

# Optimize the model
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {model_name}")
logger.info(f"Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
else:
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()

print(f"INFO: Loaded {len(passages)} passages.")
logger.info(f"Loaded {len(passages)} passages.")

def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
"""Process a batch of texts"""
batch_size = len(texts_batch)
print(f"INFO: Processing batch of size {batch_size}")
logger.info(f"Processing batch of size {batch_size}")

tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
@@ -18,10 +18,19 @@ import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys
import logging

RED = "\033[91m"
RESET = "\033[0m"

# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def is_similarity_metric():
"""
@@ -36,6 +45,23 @@ import torch
from torch import Tensor
import torch.nn.functional as F

# Timing utilities
@contextmanager
def timer(name: str, sync_cuda: bool = True):
"""Context manager for timing operations with optional CUDA sync"""
start_time = time.time()
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
try:
yield
finally:
if sync_cuda and torch.cuda.is_available():
torch.cuda.synchronize()
elif sync_cuda and torch.backends.mps.is_available():
torch.mps.synchronize()
elapsed = time.time() - start_time
logger.info(f"⏱️ {name}: {elapsed:.4f}s")


def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
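A minimal usage sketch of the timer helper added above, matching how it is applied further down in this diff (model, enc, texts_batch, and compute_embeddings_openai are names from the surrounding server code, not defined here):

    with timer("Model forward pass"):
        out = model(enc["input_ids"], enc["attention_mask"])

    # sync_cuda=False skips the CUDA/MPS synchronization for network-bound work
    with timer("OpenAI API call", sync_cuda=False):
        embeddings = compute_embeddings_openai(texts_batch, model_name)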
@@ -120,13 +146,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
logger.debug(f"Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
else:
print(f"DEBUG: ID {int_id} not found in label_map")
logger.debug(f"ID {int_id} not found in label_map")
return {"text": ""}
except Exception as e:
print(f"DEBUG: Exception getting passage {passage_id}: {e}")
logger.debug(f"Exception getting passage {passage_id}: {e}")
return {"text": ""}

def __len__(self) -> int:
@@ -184,8 +210,21 @@ def create_hnsw_embedding_server(
tokenizer = None  # MLX handles tokenization separately
else:  # sentence-transformers
print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!")
# Optimized tokenizer loading: try local first, then fallback
try:
tokenizer = AutoTokenizer.from_pretrained(
model_name,
use_fast=True,  # Use fast tokenizer (better runtime perf)
local_files_only=True  # Avoid network delays
)
print(f"Tokenizer loaded successfully! (local + fast)")
except Exception as e:
print(f"Local tokenizer failed ({e}), trying network download...")
tokenizer = AutoTokenizer.from_pretrained(
model_name,
use_fast=True  # Use fast tokenizer
)
print(f"Tokenizer loaded successfully! (network)")

# Device setup
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -217,9 +256,47 @@ def create_hnsw_embedding_server(
print("OpenAI API mode - no local model loading required")
model = None
else:
# Use standard transformers for sentence-transformers models
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully!")
# Use optimized transformers loading for sentence-transformers models
print(f"Loading model with optimizations...")
try:
# Ultra-fast loading: preload config + fast_init
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name, local_files_only=True)
model = AutoModel.from_pretrained(
model_name,
config=config,
torch_dtype=torch.float16,  # Half precision for speed
low_cpu_mem_usage=True,  # Reduce memory peaks
local_files_only=True,  # Avoid network delays
_fast_init=True  # Skip weight init checks
).to(device).eval()
print(f"Model {model_name} loaded successfully! (ultra-fast)")
except Exception as e:
print(f"Ultra-fast loading failed ({e}), trying optimized...")
try:
# Fallback: regular optimized loading
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True
).to(device).eval()
print(f"Model {model_name} loaded successfully! (optimized)")
except Exception as e2:
print(f"Optimized loading failed ({e2}), trying network...")
try:
# Fallback: optimized network loading
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
).to(device).eval()
print(f"Model {model_name} loaded successfully! (network + optimized)")
except Exception as e3:
print(f"All optimized methods failed ({e3}), using standard...")
# Final fallback: standard loading
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully! (standard)")

# Check port availability
import socket
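The claimed loading speedup comes from the from_pretrained options in this hunk (torch_dtype=torch.float16, low_cpu_mem_usage=True, local_files_only=True, plus a preloaded AutoConfig). A minimal sketch for timing the effect in isolation; the model name is a placeholder and the weights are assumed to already be in the local Hugging Face cache:

    import time
    import torch
    from transformers import AutoModel

    t0 = time.perf_counter()
    model = AutoModel.from_pretrained(
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder model name
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        local_files_only=True,
    ).eval()
    print(f"optimized load: {time.perf_counter() - t0:.2f}s")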
@@ -370,8 +447,9 @@ def create_hnsw_embedding_server(
if embedding_mode == "mlx":
return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
return compute_embeddings_openai(texts_batch, model_name)
with timer("OpenAI API call", sync_cuda=False):
from leann.api import compute_embeddings_openai
return compute_embeddings_openai(texts_batch, model_name)

_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
@@ -417,44 +495,46 @@ def create_hnsw_embedding_server(
enc = {k: v.to(device) for k, v in encoded_batch.items()}

with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
with timer("Model forward pass"):
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])

with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, "last_hidden_state"):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
with timer("Pooling"):
with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, "last_hidden_state"):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
else:
print(
f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
)
hidden_dim = getattr(
model.config, "hidden_size", 384 if _is_e5_model else 768
)
pooled_embeddings = torch.zeros(
(batch_size, hidden_dim),
device=device,
dtype=enc["input_ids"].dtype
if hasattr(enc["input_ids"], "dtype")
else torch.float32,
)
elif _is_e5_model:
pooled_embeddings = e5_average_pool(
out.last_hidden_state, enc["attention_mask"]
)
else:
print(
f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
hidden_states = out.last_hidden_state
mask_expanded = (
enc["attention_mask"]
.unsqueeze(-1)
.expand(hidden_states.size())
.float()
)
hidden_dim = getattr(
model.config, "hidden_size", 384 if _is_e5_model else 768
)
pooled_embeddings = torch.zeros(
(batch_size, hidden_dim),
device=device,
dtype=enc["input_ids"].dtype
if hasattr(enc["input_ids"], "dtype")
else torch.float32,
)
elif _is_e5_model:
pooled_embeddings = e5_average_pool(
out.last_hidden_state, enc["attention_mask"]
)
else:
hidden_states = out.last_hidden_state
mask_expanded = (
enc["attention_mask"]
.unsqueeze(-1)
.expand(hidden_states.size())
.float()
)
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask

final_embeddings = pooled_embeddings
if _is_e5_model or _is_bge_model:
@@ -536,7 +616,7 @@ def create_hnsw_embedding_server(

def zmq_server_thread():
"""ZMQ server thread"""
nonlocal passages, model, tokenizer, model_name
nonlocal passages, model, tokenizer, model_name, embedding_mode
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
@@ -556,13 +636,13 @@ def create_hnsw_embedding_server(
try:
request_payload = msgpack.unpackb(message_bytes)
if isinstance(request_payload, list):
print(f"DEBUG: request_payload length: {len(request_payload)}")
logger.debug(f"request_payload length: {len(request_payload)}")
for i, item in enumerate(request_payload):
print(
f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
)

# Handle control messages for meta path and model management
# Handle control messages for meta path and model management FIRST
if isinstance(request_payload, list) and len(request_payload) >= 1:
if request_payload[0] == "__QUERY_META_PATH__":
# Return the current meta path being used by the server
@@ -617,19 +697,61 @@ def create_hnsw_embedding_server(
)

# Clean up old model to free memory
print("INFO: Releasing old model from memory...")
logger.info("Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer

# Load new tokenizer first
# Load new tokenizer first (optimized)
print(f"Loading new tokenizer for {new_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(
new_model_name, use_fast=True
)
try:
tokenizer = AutoTokenizer.from_pretrained(
new_model_name,
use_fast=True,
local_files_only=True
)
print(f"New tokenizer loaded! (local + fast)")
except:
tokenizer = AutoTokenizer.from_pretrained(
new_model_name,
use_fast=True
)
print(f"New tokenizer loaded! (network + fast)")

# Load new model
# Load new model (optimized)
print(f"Loading new model {new_model_name}...")
model = AutoModel.from_pretrained(new_model_name)
try:
# Ultra-fast model switching
from transformers import AutoConfig
config = AutoConfig.from_pretrained(new_model_name, local_files_only=True)
model = AutoModel.from_pretrained(
new_model_name,
config=config,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True,
_fast_init=True
)
print(f"New model loaded! (ultra-fast)")
except:
try:
model = AutoModel.from_pretrained(
new_model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True
)
print(f"New model loaded! (optimized)")
except:
try:
model = AutoModel.from_pretrained(
new_model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
print(f"New model loaded! (network + optimized)")
except:
model = AutoModel.from_pretrained(new_model_name)
print(f"New model loaded! (standard)")
model.to(device)
model.eval()

@@ -640,19 +762,27 @@ def create_hnsw_embedding_server(
# Clear GPU cache if available
if device.type == "cuda":
torch.cuda.empty_cache()
print("INFO: Cleared CUDA cache")
logger.info("Cleared CUDA cache")
elif device.type == "mps":
torch.mps.empty_cache()
print("INFO: Cleared MPS cache")
logger.info("Cleared MPS cache")

# Update model name
model_name = new_model_name

# Re-detect embedding mode based on new model name
if model_name.startswith("text-embedding-"):
embedding_mode = "openai"
logger.info(f"Auto-detected embedding mode: openai for {model_name}")
else:
embedding_mode = "sentence-transformers"
logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}")

# Force garbage collection
import gc

gc.collect()
print("INFO: Memory cleanup completed")
logger.info("Memory cleanup completed")

response = ["SUCCESS"]
print(
@@ -664,6 +794,32 @@ def create_hnsw_embedding_server(
socket.send(msgpack.packb(response))
continue

# Handle direct text embedding request (for OpenAI and sentence-transformers)
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings) and NOT a control message
if (all(isinstance(item, str) for item in request_payload) and
not request_payload[0].startswith("__")):
logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode")

try:
if embedding_mode == "openai":
from leann.api import compute_embeddings_openai
embeddings = compute_embeddings_openai(request_payload, model_name)
else:
# sentence-transformers mode - compute directly
with timer(f"Direct text embedding ({len(request_payload)} texts)"):
embeddings = process_batch(request_payload, [], [])

response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
except Exception as e:
logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}")
socket.send(msgpack.packb([]))
continue

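The branch above lets any client obtain query embeddings from the running server by sending a plain list of strings over ZMQ (anything without a "__" control prefix). A minimal client sketch, mirroring the compute_embeddings_sentence_transformers path later in this diff; port 5557 is the default used there:

    import msgpack
    import numpy as np
    import zmq

    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect("tcp://localhost:5557")
    sock.send(msgpack.packb(["what does this repo do?"]))  # plain strings, no "__" prefix
    embeddings = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)
    sock.close()
    ctx.term()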
# Handle distance calculation requests
if (
isinstance(request_payload, list)
@@ -674,7 +830,7 @@ def create_hnsw_embedding_server(
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)

print("DEBUG: Distance calculation request received")
logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
print(f" Passages loaded: {len(passages)}")
@@ -684,7 +840,7 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
print(f"DEBUG: Looking up passage ID {nid}")
logger.debug(f"Looking up passage ID {nid}")
try:
txtinfo = passages[nid]
if txtinfo is None:
@@ -804,29 +960,11 @@ def create_hnsw_embedding_server(
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(
f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue

# Handle direct text embedding request (for OpenAI mode)
if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
print(f"Processing direct text embedding request for {len(request_payload)} texts")

try:
from leann.api import compute_embeddings_openai
embeddings = compute_embeddings_openai(request_payload, model_name)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
continue
except Exception as e:
print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
socket.send(msgpack.packb([]))
continue

# Standard embedding request (passage ID lookup)
if (
@@ -945,10 +1083,10 @@ def create_hnsw_embedding_server(
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

except zmq.Again:
print("ZMQ socket timeout, continuing to listen")
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"Error in ZMQ server loop: {e}")

@@ -51,7 +51,63 @@ def compute_embeddings(


def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using sentence-transformers library."""
"""Computes embeddings using sentence-transformers via embedding server."""
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)

# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
# Import ZMQ client functionality and server manager
import zmq
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager

# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")

server_started = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)

if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")

# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")

# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))

# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)

# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)

socket.close()
context.term()

return embeddings

except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)


def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
@@ -64,7 +120,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)

model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# use accelerator GPU or Mac GPU

@@ -200,7 +200,27 @@ class EmbeddingServerManager:

# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)

# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print(f"⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)

return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
@@ -230,11 +250,6 @@ class EmbeddingServerManager:
)

return True
else:
print(
f"✅ Existing server already using correct model: {model_name}"
)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@ class EmbeddingServerManager:

# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
else:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)

# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):