perf: make embedder loading faster by 6x, and embed queries through the server

This commit is contained in:
Andy Lee
2025-07-17 20:08:06 -07:00
parent 99d439577d
commit 1c5fec5565
4 changed files with 323 additions and 105 deletions

View File

@@ -51,7 +51,63 @@ def compute_embeddings(
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
    """Compute embeddings with a sentence-transformers model via the embedding server.

    The embedding server is started through ``EmbeddingServerManager`` and the
    chunks are sent to it over a ZeroMQ REQ socket as a msgpack-encoded list,
    which avoids loading the model twice (once in the API, once in the server).
    If any step of that path raises, the embeddings are computed in-process by
    ``_compute_embeddings_sentence_transformers_direct`` instead.

    Args:
        chunks: Texts to embed.
        model_name: Name of the SentenceTransformer model the server should use.

    Returns:
        A ``float32`` numpy array built from the server's reply, or whatever
        the direct fallback returns when the server path fails.
    """
    print(
        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
    )
    try:
        # Imported lazily so a missing ZMQ/msgpack install triggers the fallback
        # below instead of breaking module import.
        import zmq
        import msgpack
        import numpy as np
        from .embedding_server_manager import EmbeddingServerManager

        # Ensure embedding server is running
        port = 5557
        server_manager = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        server_started = server_manager.start_server(
            port=port,
            model_name=model_name,
            embedding_mode="sentence-transformers",
            enable_warmup=False,
        )
        if not server_started:
            raise RuntimeError(f"Failed to start embedding server on port {port}")

        # Always release the socket and context, even when connect/send/recv
        # raises; otherwise every failed attempt leaks ZeroMQ resources.
        context = zmq.Context()
        try:
            socket = context.socket(zmq.REQ)
            try:
                socket.connect(f"tcp://localhost:{port}")
                # Send chunks to server for embedding computation
                socket.send(msgpack.packb(chunks))
                # Receive embeddings from server
                response = socket.recv()
            finally:
                # linger=0 drops any unsent request so context.term() below
                # cannot block on a pending message.
                socket.close(linger=0)
        finally:
            context.term()

        # Convert back to numpy array
        embeddings_list = msgpack.unpackb(response)
        return np.array(embeddings_list, dtype=np.float32)
    except Exception as e:
        # Deliberate best-effort: any failure on the server path falls back to
        # direct sentence-transformers computation.
        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
@@ -64,7 +120,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# Use an accelerator GPU or the Mac GPU

View File

@@ -200,7 +200,27 @@ class EmbeddingServerManager:
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)
# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
# The f-prefix is required so the new meta path is interpolated into the message.
print(f"⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
@@ -230,11 +250,6 @@ class EmbeddingServerManager:
)
return True
else:
print(
f"✅ Existing server already using correct model: {model_name}"
)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@ class EmbeddingServerManager:
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if not model_matches:
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
else:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):