perf: reuse embedding server for query embed

Andy Lee
2025-07-16 16:12:15 -07:00
parent 2a1a152073
commit f77c4e38cb
4 changed files with 169 additions and 38 deletions
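
In short: LeannSearcher.search() no longer recomputes the query embedding by loading the embedding model in-process. It now asks the backend via compute_query_embedding(), which reuses the already-running ZMQ embedding server when one is available and falls back to direct model loading only if the server call fails. A rough usage sketch of the new call path (the index path, query text, and parameter values below are illustrative, not taken from this commit):

# Illustrative sketch only; "demo.leann" and the query string are placeholders.
searcher = LeannSearcher("demo.leann", enable_warmup=True)

# The query embedding is produced by the backend, which talks to the embedding
# server over ZMQ (default port 5557) instead of reloading the model here.
results = searcher.search(
    "how is the embedding server reused?",
    top_k=5,
    complexity=64,
    recompute_embeddings=False,
    zmq_port=5557,
)
for r in results:
    print(r.text)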

View File

@@ -1,7 +1,3 @@
import faulthandler
faulthandler.enable()
import argparse
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.node_parser import SentenceSplitter
@@ -62,7 +58,7 @@ async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

View File

@@ -7,7 +7,7 @@ import json
import pickle
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
import uuid
import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
def search(
self,
query: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}")
print(f" Additional kwargs: {kwargs}")
query_embedding = compute_embeddings(
[query], self.embedding_model, self.use_mlx
)
# Use backend's compute_query_embedding method
# This will automatically use the embedding server if it is available and needed
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
# Add use_mlx to search kwargs
search_kwargs["use_mlx"] = self.use_mlx
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
results = self.backend_impl.search(
query_embedding,
top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
)
print(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
@@ -309,8 +328,33 @@ class LeannChat:
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self.llm = get_llm(llm_config)
def ask(self, question: str, top_k=5, **kwargs):
results = self.searcher.search(question, top_k=top_k, **kwargs)
def ask(
self,
question: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
llm_kwargs: Optional[Dict[str, Any]] = None,
**search_kwargs,
):
if llm_kwargs is None:
llm_kwargs = {}
results = self.searcher.search(
question,
top_k=top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**search_kwargs,
)
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
return self.llm.ask(prompt, **llm_kwargs)
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")

View File

@@ -2,13 +2,16 @@ from abc import ABC, abstractmethod
import numpy as np
from typing import Dict, Any, List, Literal
class LeannBackendBuilderInterface(ABC):
"""Backend interface for building indexes"""
@abstractmethod
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
@abstractmethod
def build(
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
) -> None:
"""Build index
Args:
data: Vector data (N, D)
ids: List of string IDs for each vector
@@ -17,30 +20,35 @@ class LeannBackendBuilderInterface(ABC):
"""
pass
class LeannBackendSearcherInterface(ABC):
"""Backend interface for searching"""
@abstractmethod
def __init__(self, index_path: str, **kwargs):
"""Initialize searcher
Args:
index_path: Path to index file
**kwargs: Backend-specific loading parameters
"""
pass
@abstractmethod
def search(self, query: np.ndarray, top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs) -> Dict[str, Any]:
def search(
self,
query: np.ndarray,
top_k: int,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors
Args:
query: Query vectors (B, D) where B is batch size, D is dimension
top_k: Number of nearest neighbors to return
@@ -51,23 +59,40 @@ class LeannBackendSearcherInterface(ABC):
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters
Returns:
{"labels": [...], "distances": [...]}
"""
pass
@abstractmethod
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""Compute embedding for a query string
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array with shape (1, D)
"""
pass
class LeannBackendFactoryInterface(ABC):
"""Backend factory interface"""
@staticmethod
@abstractmethod
def builder(**kwargs) -> LeannBackendBuilderInterface:
"""Create Builder instance"""
pass
@staticmethod
@abstractmethod
@abstractmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
"""Create Searcher instance"""
pass
pass
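
With compute_query_embedding() added to LeannBackendSearcherInterface, every backend must be able to turn a raw query string into a (1, D) embedding itself, so callers such as LeannSearcher never have to load the embedding model directly. A minimal sketch of a conforming searcher, assuming it lives alongside the interface definitions; the class name and the encode() call on self.model are placeholders, not part of this commit:

import numpy as np
from typing import Any, Dict

class ToySearcher(LeannBackendSearcherInterface):
    """Placeholder backend illustrating the new interface contract."""

    def __init__(self, index_path: str, **kwargs):
        self.index_path = index_path
        # Assumed: a model object exposing encode(list[str]) -> array-like of shape (N, D).
        self.model = kwargs["model"]

    def compute_query_embedding(
        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
    ) -> np.ndarray:
        # A real backend would try the ZMQ embedding server first (see BaseSearcher
        # in the next file); this toy version encodes locally and returns the
        # (1, D) shape the contract requires.
        return np.asarray(self.model.encode([query]), dtype=np.float32).reshape(1, -1)

    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
        # Placeholder result in the documented {"labels": ..., "distances": ...} form.
        return {"labels": [[]], "distances": [[]]}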

View File

@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""
Compute embedding for a query string.
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array
"""
# Try to use the embedding server if it is available and requested
if (
use_server_if_available
and self.embedding_server_manager
and self.embedding_server_manager.server_process
):
try:
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
except Exception as e:
print(f"⚠️ Embedding server failed: {e}")
print("⏭️ Falling back to direct model loading...")
# Fallback to direct computation
from .api import compute_embeddings
use_mlx = self.meta.get("use_mlx", False)
return compute_embeddings([query], self.embedding_model, use_mlx)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""
import zmq
import msgpack
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
socket.connect(f"tcp://localhost:{zmq_port}")
# Send embedding request
request = chunks
request_bytes = msgpack.packb(request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Convert response to numpy array
if isinstance(response, list) and len(response) > 0:
return np.array(response, dtype=np.float32)
else:
raise RuntimeError("Invalid response from embedding server")
except Exception as e:
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
@abstractmethod
def search(
self,
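
For reference, the wire format implied by _compute_embedding_via_server() is simple: the client sends a msgpack-encoded list of strings over a ZMQ REQ socket and expects a msgpack-encoded list of embedding vectors back. A rough sketch of what a matching server loop could look like; the model object and its encode() method are assumptions, and the embedding server actually shipped with Leann is not shown in this diff:

import msgpack
import numpy as np
import zmq

def serve_embeddings(model, zmq_port: int = 5557) -> None:
    """Toy REP loop matching the REQ protocol used by _compute_embedding_via_server()."""
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://*:{zmq_port}")
    while True:
        # Request: msgpack-encoded list of query/chunk strings.
        chunks = msgpack.unpackb(socket.recv())
        # Response: msgpack-encoded list of float32 embedding rows.
        vectors = np.asarray(model.encode(chunks), dtype=np.float32)
        socket.send(msgpack.packb(vectors.tolist()))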