perf: make embedder loading faster by 6x, and embed queries through the server
This commit is contained in:
@@ -51,7 +51,63 @@ def compute_embeddings(
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers library."""
|
||||
"""Computes embeddings using sentence-transformers via embedding server."""
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
|
||||
# Use embedding server for sentence-transformers too
|
||||
# This avoids loading the model twice (once in API, once in server)
|
||||
try:
|
||||
# Import ZMQ client functionality and server manager
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
|
||||
# Ensure embedding server is running
|
||||
port = 5557
|
||||
server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
|
||||
|
||||
server_started = server_manager.start_server(
|
||||
port=port,
|
||||
model_name=model_name,
|
||||
embedding_mode="sentence-transformers",
|
||||
enable_warmup=False,
|
||||
)
|
||||
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to direct sentence-transformers if server connection fails
|
||||
print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Direct sentence-transformers computation (fallback)."""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
@@ -64,7 +120,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
|
||||
|
||||
model = model.half()
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
# use acclerater GPU or MAC GPU
|
||||
|
||||
|
||||
Reference in New Issue
Block a user