perf: reuse embedding server for query embed

Andy Lee committed 2025-07-16 16:12:15 -07:00
parent 2a1a152073
commit f77c4e38cb
4 changed files with 169 additions and 38 deletions

View File

@@ -1,7 +1,3 @@
import faulthandler
faulthandler.enable()
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.node_parser import SentenceSplitter
@@ -62,7 +58,7 @@ async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

View File

@@ -7,7 +7,7 @@ import json
import pickle
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
import uuid
import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
def search(
self,
query: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}")
print(f" Additional kwargs: {kwargs}")
query_embedding = compute_embeddings(
[query], self.embedding_model, self.use_mlx
)
# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
# Add use_mlx to search kwargs
search_kwargs["use_mlx"] = self.use_mlx
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
results = self.backend_impl.search(
query_embedding,
top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
)
print(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
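
For reference, a minimal usage sketch of the reworked search() signature. The import path and index location are assumptions for illustration; setting recompute_embeddings=True together with zmq_port routes the query embedding (and any embedding recomputation during graph traversal) through the already-running embedding server instead of loading the model a second time.

    # Hypothetical usage; package import path and index path are assumed.
    from leann.api import LeannSearcher

    searcher = LeannSearcher("./docs.leann")
    results = searcher.search(
        "how is the embedding server reused for query embedding?",
        top_k=5,
        complexity=64,
        recompute_embeddings=True,  # reuse the ZMQ embedding server
        zmq_port=5557,
    )
    for r in results:
        print(r.text[:80])  # SearchResult exposes the chunk text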
@@ -309,8 +328,33 @@ class LeannChat:
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self.llm = get_llm(llm_config)
def ask(self, question: str, top_k=5, **kwargs):
results = self.searcher.search(question, top_k=top_k, **kwargs)
def ask(
self,
question: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
llm_kwargs: Optional[Dict[str, Any]] = None,
**search_kwargs,
):
if llm_kwargs is None:
llm_kwargs = {}
results = self.searcher.search(
question,
top_k=top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**search_kwargs,
)
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
return self.llm.ask(prompt, **llm_kwargs)
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")

View File

@@ -2,11 +2,14 @@ from abc import ABC, abstractmethod
import numpy as np
from typing import Dict, Any, List, Literal
class LeannBackendBuilderInterface(ABC):
"""Backend interface for building indexes"""
@abstractmethod
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
def build(
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
) -> None:
"""Build index
Args:
@@ -17,6 +20,7 @@ class LeannBackendBuilderInterface(ABC):
"""
pass
class LeannBackendSearcherInterface(ABC):
"""Backend interface for searching"""
@@ -31,14 +35,18 @@ class LeannBackendSearcherInterface(ABC):
pass
@abstractmethod
def search(self, query: np.ndarray, top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs) -> Dict[str, Any]:
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors
Args:
@@ -57,6 +65,23 @@ class LeannBackendSearcherInterface(ABC):
"""
pass
@abstractmethod
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""Compute embedding for a query string
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array with shape (1, D)
"""
pass
class LeannBackendFactoryInterface(ABC):
"""Backend factory interface"""

View File

@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""
Compute embedding for a query string.
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array
"""
# Try to use embedding server if available and requested
if (
use_server_if_available
and self.embedding_server_manager
and self.embedding_server_manager.server_process
):
try:
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
except Exception as e:
print(f"⚠️ Embedding server failed: {e}")
print("⏭️ Falling back to direct model loading...")
# Fallback to direct computation
from .api import compute_embeddings
use_mlx = self.meta.get("use_mlx", False)
return compute_embeddings([query], self.embedding_model, use_mlx)
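
Callers that want to bypass the server path entirely can pass use_server_if_available=False; a short illustration using the backend handle that LeannSearcher keeps (assuming a LeannSearcher instance named searcher, as in the earlier sketch):

    # Force the direct-model fallback instead of the ZMQ server.
    emb = searcher.backend_impl.compute_query_embedding(
        "hello world", zmq_port=5557, use_server_if_available=False
    )
    assert emb.shape[0] == 1  # one query row of dimension D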
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""
import zmq
import msgpack
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
socket.connect(f"tcp://localhost:{zmq_port}")
# Send embedding request
request = chunks
request_bytes = msgpack.packb(request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Convert response to numpy array
if isinstance(response, list) and len(response) > 0:
return np.array(response, dtype=np.float32)
else:
raise RuntimeError("Invalid response from embedding server")
except Exception as e:
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
@abstractmethod
def search(
self,