perf: make embedder loading faster by 6x, and embed queries through the server

Andy Lee
2025-07-17 20:08:06 -07:00
parent 99d439577d
commit 1c5fec5565
4 changed files with 323 additions and 105 deletions

View File

@@ -16,8 +16,17 @@ import zmq
import numpy as np
import msgpack
from pathlib import Path
+import logging
RED = "\033[91m"
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
RESET = "\033[0m"
# --- New Passage Loader from HNSW backend ---
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
Create and run the embedding server in the current thread.
This function is designed to be called from a separate thread.
"""
-print(f"INFO: Initializing embedding server thread on port {zmq_port}")
+logger.info(f"Initializing embedding server thread on port {zmq_port}")
try:
    # Check whether the port is already in use
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
if embedding_mode == "mlx":
    from leann.api import compute_embeddings_mlx
    import torch
-    print("INFO: Using MLX for embeddings")
+    logger.info("Using MLX for embeddings")
    # Set device to CPU for compatibility with DeviceTimer class
    device = torch.device("cpu")
    cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
elif embedding_mode == "openai":
    from leann.api import compute_embeddings_openai
    import torch
-    print("INFO: Using OpenAI API for embeddings")
+    logger.info("Using OpenAI API for embeddings")
    # Set device to CPU for compatibility with DeviceTimer class
    device = torch.device("cpu")
    cuda_available = False
@@ -213,16 +222,16 @@ def create_embedding_server_thread(
if cuda_available:
    device = torch.device("cuda")
-    print("INFO: Using CUDA device")
+    logger.info("Using CUDA device")
elif mps_available:
    device = torch.device("mps")
-    print("INFO: Using MPS device (Apple Silicon)")
+    logger.info("Using MPS device (Apple Silicon)")
else:
    device = torch.device("cpu")
-    print("INFO: Using CPU device")
+    logger.info("Using CPU device")
# Load the model
-print(f"INFO: Loading model {model_name}")
+logger.info(f"Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# Optimize the model
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
try:
    model = model.half()
    model = torch.compile(model)
-    print(f"INFO: Using FP16 precision with model: {model_name}")
+    logger.info(f"Using FP16 precision with model: {model_name}")
except Exception as e:
    print(f"WARNING: Model optimization failed: {e}")
else:
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
print("WARNING: No passages file provided or file not found. Using an empty passage loader.") print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader() passages = SimplePassageLoader()
print(f"INFO: Loaded {len(passages)} passages.") logger.info(f"Loaded {len(passages)} passages.")
def client_warmup(zmq_port): def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server""" """Perform client-side warmup for DiskANN server"""
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
    """Process a batch of texts"""
    batch_size = len(texts_batch)
-    print(f"INFO: Processing batch of size {batch_size}")
+    logger.info(f"Processing batch of size {batch_size}")
    tokenize_timer = DeviceTimer("tokenization (batch)", device)
    to_device_timer = DeviceTimer("transfer to device (batch)", device)

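The print-to-logger migration in this file (and in the HNSW server below) is controlled by a single environment variable read at import time. The following is a small illustrative sketch of that pattern, not code lifted from the repo, showing how `LEANN_LOG_LEVEL` decides which of the messages above actually appear:

```python
import logging
import os

# Same pattern as the servers in this commit: the environment variable picks the
# level once, at import time, and everything else goes through the logging module.
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

logger.debug("shown only when LEANN_LOG_LEVEL=DEBUG (per-passage lookups, payload dumps)")
logger.info("shown at the default INFO level (device selection, batch sizes, timings)")
```

Because `logging.basicConfig()` runs when the server module is imported, the variable has to be set in the server process environment before that import happens.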
View File

@@ -18,10 +18,19 @@ import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys
+import logging
RED = "\033[91m"
RESET = "\033[0m"
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
def is_similarity_metric():
    """
@@ -36,6 +45,23 @@ import torch
from torch import Tensor
import torch.nn.functional as F
+# Timing utilities
+@contextmanager
+def timer(name: str, sync_cuda: bool = True):
+    """Context manager for timing operations with optional CUDA sync"""
+    start_time = time.time()
+    if sync_cuda and torch.cuda.is_available():
+        torch.cuda.synchronize()
+    try:
+        yield
+    finally:
+        if sync_cuda and torch.cuda.is_available():
+            torch.cuda.synchronize()
+        elif sync_cuda and torch.backends.mps.is_available():
+            torch.mps.synchronize()
+        elapsed = time.time() - start_time
+        logger.info(f"⏱️ {name}: {elapsed:.4f}s")
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
@@ -120,13 +146,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
if passage_data and passage_data.get("text"):
    return {"text": passage_data["text"]}
else:
-    print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
+    logger.debug(f"Empty text for ID {int_id} -> {string_id}")
    return {"text": ""}
else:
-    print(f"DEBUG: ID {int_id} not found in label_map")
+    logger.debug(f"ID {int_id} not found in label_map")
    return {"text": ""}
except Exception as e:
-    print(f"DEBUG: Exception getting passage {passage_id}: {e}")
+    logger.debug(f"Exception getting passage {passage_id}: {e}")
    return {"text": ""}
def __len__(self) -> int:
@@ -184,8 +210,21 @@ def create_hnsw_embedding_server(
    tokenizer = None  # MLX handles tokenization separately
else:  # sentence-transformers
    print(f"Loading tokenizer for {model_name}...")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-    print(f"Tokenizer loaded successfully!")
+    # Optimized tokenizer loading: try local first, then fallback
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            use_fast=True,          # Use fast tokenizer (better runtime perf)
+            local_files_only=True   # Avoid network delays
+        )
+        print(f"Tokenizer loaded successfully! (local + fast)")
+    except Exception as e:
+        print(f"Local tokenizer failed ({e}), trying network download...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            use_fast=True           # Use fast tokenizer
+        )
+        print(f"Tokenizer loaded successfully! (network)")
# Device setup
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -217,9 +256,47 @@ def create_hnsw_embedding_server(
print("OpenAI API mode - no local model loading required") print("OpenAI API mode - no local model loading required")
model = None model = None
else: else:
# Use standard transformers for sentence-transformers models # Use optimized transformers loading for sentence-transformers models
model = AutoModel.from_pretrained(model_name).to(device).eval() print(f"Loading model with optimizations...")
print(f"Model {model_name} loaded successfully!") try:
# Ultra-fast loading: preload config + fast_init
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name, local_files_only=True)
model = AutoModel.from_pretrained(
model_name,
config=config,
torch_dtype=torch.float16, # Half precision for speed
low_cpu_mem_usage=True, # Reduce memory peaks
local_files_only=True, # Avoid network delays
_fast_init=True # Skip weight init checks
).to(device).eval()
print(f"Model {model_name} loaded successfully! (ultra-fast)")
except Exception as e:
print(f"Ultra-fast loading failed ({e}), trying optimized...")
try:
# Fallback: regular optimized loading
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
local_files_only=True
).to(device).eval()
print(f"Model {model_name} loaded successfully! (optimized)")
except Exception as e2:
print(f"Optimized loading failed ({e2}), trying network...")
try:
# Fallback: optimized network loading
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
).to(device).eval()
print(f"Model {model_name} loaded successfully! (network + optimized)")
except Exception as e3:
print(f"All optimized methods failed ({e3}), using standard...")
# Final fallback: standard loading
model = AutoModel.from_pretrained(model_name).to(device).eval()
print(f"Model {model_name} loaded successfully! (standard)")
# Check port availability # Check port availability
import socket import socket
@@ -370,8 +447,9 @@ def create_hnsw_embedding_server(
if embedding_mode == "mlx":
    return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
elif embedding_mode == "openai":
-    from leann.api import compute_embeddings_openai
-    return compute_embeddings_openai(texts_batch, model_name)
+    with timer("OpenAI API call", sync_cuda=False):
+        from leann.api import compute_embeddings_openai
+        return compute_embeddings_openai(texts_batch, model_name)
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
@@ -417,44 +495,46 @@ def create_hnsw_embedding_server(
enc = {k: v.to(device) for k, v in encoded_batch.items()}
with torch.no_grad():
+    with timer("Model forward pass"):
        with embed_timer.timing():
            out = model(enc["input_ids"], enc["attention_mask"])
+    with timer("Pooling"):
        with pool_timer.timing():
            if _is_bge_model:
                pooled_embeddings = out.last_hidden_state[:, 0]
            elif not hasattr(out, "last_hidden_state"):
                if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                    pooled_embeddings = out
                else:
                    print(
                        f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
                    )
                    hidden_dim = getattr(
                        model.config, "hidden_size", 384 if _is_e5_model else 768
                    )
                    pooled_embeddings = torch.zeros(
                        (batch_size, hidden_dim),
                        device=device,
                        dtype=enc["input_ids"].dtype
                        if hasattr(enc["input_ids"], "dtype")
                        else torch.float32,
                    )
            elif _is_e5_model:
                pooled_embeddings = e5_average_pool(
                    out.last_hidden_state, enc["attention_mask"]
                )
            else:
                hidden_states = out.last_hidden_state
                mask_expanded = (
                    enc["attention_mask"]
                    .unsqueeze(-1)
                    .expand(hidden_states.size())
                    .float()
                )
                sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
                sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                pooled_embeddings = sum_embeddings / sum_mask
final_embeddings = pooled_embeddings
if _is_e5_model or _is_bge_model:
@@ -536,7 +616,7 @@ def create_hnsw_embedding_server(
def zmq_server_thread():
    """ZMQ server thread"""
-    nonlocal passages, model, tokenizer, model_name
+    nonlocal passages, model, tokenizer, model_name, embedding_mode
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://*:{zmq_port}")
@@ -556,13 +636,13 @@ def create_hnsw_embedding_server(
try:
    request_payload = msgpack.unpackb(message_bytes)
    if isinstance(request_payload, list):
-        print(f"DEBUG: request_payload length: {len(request_payload)}")
+        logger.debug(f"request_payload length: {len(request_payload)}")
        for i, item in enumerate(request_payload):
            print(
                f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
            )
-    # Handle control messages for meta path and model management
+    # Handle control messages for meta path and model management FIRST
    if isinstance(request_payload, list) and len(request_payload) >= 1:
        if request_payload[0] == "__QUERY_META_PATH__":
            # Return the current meta path being used by the server
@@ -617,19 +697,61 @@ def create_hnsw_embedding_server(
)
# Clean up old model to free memory
-print("INFO: Releasing old model from memory...")
+logger.info("Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer
-# Load new tokenizer first
+# Load new tokenizer first (optimized)
print(f"Loading new tokenizer for {new_model_name}...")
-tokenizer = AutoTokenizer.from_pretrained(
-    new_model_name, use_fast=True
-)
+try:
+    tokenizer = AutoTokenizer.from_pretrained(
+        new_model_name,
+        use_fast=True,
+        local_files_only=True
+    )
+    print(f"New tokenizer loaded! (local + fast)")
+except:
+    tokenizer = AutoTokenizer.from_pretrained(
+        new_model_name,
+        use_fast=True
+    )
+    print(f"New tokenizer loaded! (network + fast)")
-# Load new model
+# Load new model (optimized)
print(f"Loading new model {new_model_name}...")
-model = AutoModel.from_pretrained(new_model_name)
+try:
+    # Ultra-fast model switching
+    from transformers import AutoConfig
+    config = AutoConfig.from_pretrained(new_model_name, local_files_only=True)
+    model = AutoModel.from_pretrained(
+        new_model_name,
+        config=config,
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+        local_files_only=True,
+        _fast_init=True
+    )
+    print(f"New model loaded! (ultra-fast)")
+except:
+    try:
+        model = AutoModel.from_pretrained(
+            new_model_name,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            local_files_only=True
+        )
+        print(f"New model loaded! (optimized)")
+    except:
+        try:
+            model = AutoModel.from_pretrained(
+                new_model_name,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+            print(f"New model loaded! (network + optimized)")
+        except:
+            model = AutoModel.from_pretrained(new_model_name)
+            print(f"New model loaded! (standard)")
model.to(device)
model.eval()
@@ -640,19 +762,27 @@ def create_hnsw_embedding_server(
# Clear GPU cache if available
if device.type == "cuda":
    torch.cuda.empty_cache()
-    print("INFO: Cleared CUDA cache")
+    logger.info("Cleared CUDA cache")
elif device.type == "mps":
    torch.mps.empty_cache()
-    print("INFO: Cleared MPS cache")
+    logger.info("Cleared MPS cache")
# Update model name
model_name = new_model_name
+# Re-detect embedding mode based on new model name
+if model_name.startswith("text-embedding-"):
+    embedding_mode = "openai"
+    logger.info(f"Auto-detected embedding mode: openai for {model_name}")
+else:
+    embedding_mode = "sentence-transformers"
+    logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}")
# Force garbage collection
import gc
gc.collect()
-print("INFO: Memory cleanup completed")
+logger.info("Memory cleanup completed")
response = ["SUCCESS"]
print(
@@ -664,6 +794,32 @@ def create_hnsw_embedding_server(
    socket.send(msgpack.packb(response))
    continue
+# Handle direct text embedding request (for OpenAI and sentence-transformers)
+if isinstance(request_payload, list) and len(request_payload) > 0:
+    # Check if this is a direct text request (list of strings) and NOT a control message
+    if (all(isinstance(item, str) for item in request_payload) and
+            not request_payload[0].startswith("__")):
+        logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode")
+        try:
+            if embedding_mode == "openai":
+                from leann.api import compute_embeddings_openai
+                embeddings = compute_embeddings_openai(request_payload, model_name)
+            else:
+                # sentence-transformers mode - compute directly
+                with timer(f"Direct text embedding ({len(request_payload)} texts)"):
+                    embeddings = process_batch(request_payload, [], [])
+            response = embeddings.tolist()
+            socket.send(msgpack.packb(response))
+            e2e_end = time.time()
+            logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
+            continue
+        except Exception as e:
+            logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}")
+            socket.send(msgpack.packb([]))
+            continue
# Handle distance calculation requests
if (
    isinstance(request_payload, list)
@@ -674,7 +830,7 @@ def create_hnsw_embedding_server(
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
-print("DEBUG: Distance calculation request received")
+logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
print(f" Passages loaded: {len(passages)}")
@@ -684,7 +840,7 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
    for nid in node_ids:
-        print(f"DEBUG: Looking up passage ID {nid}")
+        logger.debug(f"Looking up passage ID {nid}")
        try:
            txtinfo = passages[nid]
            if txtinfo is None:
@@ -804,29 +960,11 @@ def create_hnsw_embedding_server(
elif device.type == "mps":
    torch.mps.synchronize()
e2e_end = time.time()
-print(
-    f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
+logger.info(
+    f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
-# Handle direct text embedding request (for OpenAI mode)
-if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
-    # Check if this is a direct text request (list of strings)
-    if all(isinstance(item, str) for item in request_payload):
-        print(f"Processing direct text embedding request for {len(request_payload)} texts")
-        try:
-            from leann.api import compute_embeddings_openai
-            embeddings = compute_embeddings_openai(request_payload, model_name)
-            response = embeddings.tolist()
-            socket.send(msgpack.packb(response))
-            e2e_end = time.time()
-            print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
-            continue
-        except Exception as e:
-            print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
-            socket.send(msgpack.packb([]))
-            continue
# Standard embedding request (passage ID lookup)
if (
@@ -945,10 +1083,10 @@ def create_hnsw_embedding_server(
elif device.type == "mps":
    torch.mps.synchronize()
e2e_end = time.time()
-print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
+logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
-    print("ZMQ socket timeout, continuing to listen")
+    logger.debug("ZMQ socket timeout, continuing to listen")
    continue
except Exception as e:
    print(f"Error in ZMQ server loop: {e}")

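The loading cascade above is where most of the claimed speedup comes from: when the weights are already in the local Hugging Face cache, `local_files_only=True` skips the Hub's HTTP round-trips, while fp16 weights and `low_cpu_mem_usage=True` reduce deserialization time and peak memory. A condensed, illustrative helper with the same fallback order (the function name and structure here are mine, not the repo's):

```python
import torch
from transformers import AutoModel

def load_model_fast(model_name: str, device: torch.device):
    """Illustrative sketch: try cached fp16 loading first, then allow a network
    download, and only fall back to plain from_pretrained() as a last resort."""
    attempts = [
        dict(torch_dtype=torch.float16, low_cpu_mem_usage=True, local_files_only=True),
        dict(torch_dtype=torch.float16, low_cpu_mem_usage=True),  # cache miss: download
        dict(),                                                   # standard loading
    ]
    last_err = None
    for kwargs in attempts:
        try:
            return AutoModel.from_pretrained(model_name, **kwargs).to(device).eval()
        except Exception as err:  # e.g. weights not cached yet
            last_err = err
    raise last_err
```

The `_fast_init=True` flag used in the commit is a private transformers argument, so the sketch above leaves it out.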
View File

@@ -51,7 +51,63 @@ def compute_embeddings(
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using sentence-transformers library."""
+    """Computes embeddings using sentence-transformers via embedding server."""
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
+    )
+    # Use embedding server for sentence-transformers too
+    # This avoids loading the model twice (once in API, once in server)
+    try:
+        # Import ZMQ client functionality and server manager
+        import zmq
+        import msgpack
+        import numpy as np
+        from .embedding_server_manager import EmbeddingServerManager
+
+        # Ensure embedding server is running
+        port = 5557
+        server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
+        server_started = server_manager.start_server(
+            port=port,
+            model_name=model_name,
+            embedding_mode="sentence-transformers",
+            enable_warmup=False,
+        )
+        if not server_started:
+            raise RuntimeError(f"Failed to start embedding server on port {port}")
+
+        # Connect to embedding server
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send chunks to server for embedding computation
+        request = chunks
+        socket.send(msgpack.packb(request))
+
+        # Receive embeddings from server
+        response = socket.recv()
+        embeddings_list = msgpack.unpackb(response)
+
+        # Convert back to numpy array
+        embeddings = np.array(embeddings_list, dtype=np.float32)
+        socket.close()
+        context.term()
+        return embeddings
+    except Exception as e:
+        # Fallback to direct sentence-transformers if server connection fails
+        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
+        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
+
+def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
+    """Direct sentence-transformers computation (fallback)."""
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as e:
@@ -64,7 +120,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
model = model.half()
print(
-    f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
+    f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# Use accelerator GPU or Mac GPU

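With this change, query embedding goes through the long-lived server on port 5557, so the API process no longer has to load the model a second time; if the server cannot be reached, the direct path above is used as a fallback. A hedged usage sketch, assuming `compute_embeddings_sentence_transformers` lives in `leann.api` alongside the other `compute_embeddings_*` helpers shown in this diff (the model id is only an example):

```python
import numpy as np
from leann.api import compute_embeddings_sentence_transformers

# Example inputs; any sentence-transformers-compatible model id should behave the same.
chunks = ["how do I rebuild the index?", "what is the default ZMQ port?"]
embeddings = compute_embeddings_sentence_transformers(chunks, "BAAI/bge-small-en-v1.5")

assert isinstance(embeddings, np.ndarray)
print(embeddings.shape)  # (2, hidden_dim); float32 when the vectors come back from the server
```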
View File

@@ -200,7 +200,27 @@ class EmbeddingServerManager:
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
-if not model_matches:
+if model_matches:
+    print(
+        f"✅ Existing server already using correct model: {model_name}"
+    )
+    # Still check meta path if provided
+    passages_file = kwargs.get("passages_file")
+    if passages_file and str(passages_file).endswith(
+        ".meta.json"
+    ):
+        meta_matches = _check_server_meta_path(
+            self.server_port, str(passages_file)
+        )
+        if not meta_matches:
+            print(f"⚠️ Updating meta path to: {passages_file}")
+            _update_server_meta_path(
+                self.server_port, str(passages_file)
+            )
+    return True
+else:
    print(
        f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
    )
@@ -230,11 +250,6 @@ class EmbeddingServerManager:
    )
    return True
-else:
-    print(
-        f"✅ Existing server already using correct model: {model_name}"
-    )
-    return True
else:
    # Server process exists but port not responding - restart
    print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@ class EmbeddingServerManager:
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
-if not model_matches:
+if model_matches:
+    print(
+        f"✅ Existing server on port {port} is using correct model: {model_name}"
+    )
+else:
    print(
        f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
    )
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
f"❌ Failed to update server model to {model_name}. Consider using a different port." f"❌ Failed to update server model to {model_name}. Consider using a different port."
) )
print(f"✅ Successfully updated server model to: {model_name}") print(f"✅ Successfully updated server model to: {model_name}")
else:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
# Check meta path compatibility if provided # Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"): if passages_file and str(passages_file).endswith(".meta.json"):