perf: reuse embedding server for query embed
This commit is contained in:
@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
|
||||
Args:
|
||||
query: The query string to embed
|
||||
zmq_port: ZMQ port for embedding server
|
||||
use_server_if_available: Whether to try using embedding server first
|
||||
|
||||
Returns:
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Try to use embedding server if available and requested
|
||||
if (
|
||||
use_server_if_available
|
||||
and self.embedding_server_manager
|
||||
and self.embedding_server_manager.server_process
|
||||
):
|
||||
try:
|
||||
return self._compute_embedding_via_server([query], zmq_port)[
|
||||
0:1
|
||||
] # Return (1, D) shape
|
||||
except Exception as e:
|
||||
print(f"⚠️ Embedding server failed: {e}")
|
||||
print("⏭️ Falling back to direct model loading...")
|
||||
|
||||
# Fallback to direct computation
|
||||
from .api import compute_embeddings
|
||||
|
||||
use_mlx = self.meta.get("use_mlx", False)
|
||||
return compute_embeddings([query], self.embedding_model, use_mlx)
|
||||
|
||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||
"""Compute embeddings using the ZMQ embedding server."""
|
||||
import zmq
|
||||
import msgpack
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
|
||||
# Send embedding request
|
||||
request = chunks
|
||||
request_bytes = msgpack.packb(request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Convert response to numpy array
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return np.array(response, dtype=np.float32)
|
||||
else:
|
||||
raise RuntimeError("Invalid response from embedding server")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user