perf: reuse embedding server for query embed

Author: Andy Lee
Date: 2025-07-16 16:12:15 -07:00
parent 2a1a152073
commit f77c4e38cb
4 changed files with 169 additions and 38 deletions

View File

@@ -1,7 +1,3 @@
-import faulthandler
-faulthandler.enable()
 import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.node_parser import SentenceSplitter
@@ -62,7 +58,7 @@ async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...") print(f"\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"} llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config) chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

View File

@@ -7,7 +7,7 @@ import json
 import pickle
 import numpy as np
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
 import uuid
 import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
final_kwargs["enable_warmup"] = enable_warmup final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs) self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]: def search(
self,
query: str,
top_k: int = 5,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
print("🔍 DEBUG LeannSearcher.search() called:") print("🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'") print(f" Query: '{query}'")
print(f" Top_k: {top_k}") print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}") print(f" Additional kwargs: {kwargs}")
query_embedding = compute_embeddings( # Use backend's compute_query_embedding method
[query], self.embedding_model, self.use_mlx # This will automatically use embedding server if available and needed
) query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
print(f" Generated embedding shape: {query_embedding.shape}") print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}") print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}") print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
# Add use_mlx to search kwargs results = self.backend_impl.search(
search_kwargs["use_mlx"] = self.use_mlx query_embedding,
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs) top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
)
print( print(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results" f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
) )
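For reference, a minimal usage sketch of the new explicit search() signature introduced above. The import path and index path are placeholders (not shown in this diff); the keyword values are simply the new defaults.

# Hypothetical caller of the updated LeannSearcher.search(); the import path
# and index path are placeholders, and every keyword shown is the new default.
from leann.api import LeannSearcher  # module path assumed, not shown in this diff

searcher = LeannSearcher("indexes/my_docs.leann")
results = searcher.search(
    "how is the query embedding computed?",
    top_k=5,
    complexity=64,
    beam_width=1,
    prune_ratio=0.0,
    recompute_embeddings=False,
    pruning_strategy="global",  # or "local" / "proportional"
    zmq_port=5557,              # port of the backend's embedding server
)
for r in results:
    print(r.text[:80])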
@@ -309,8 +328,33 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)

-    def ask(self, question: str, top_k=5, **kwargs):
-        results = self.searcher.search(question, top_k=top_k, **kwargs)
+    def ask(
+        self,
+        question: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **search_kwargs,
+    ):
+        if llm_kwargs is None:
+            llm_kwargs = {}
+        results = self.searcher.search(
+            question,
+            top_k=top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **search_kwargs,
+        )
         context = "\n\n".join([r.text for r in results])
         prompt = (
             "Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
f"Question: {question}\n\n" f"Question: {question}\n\n"
"Please provide the best answer you can based on this context and your knowledge." "Please provide the best answer you can based on this context and your knowledge."
) )
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {})) return self.llm.ask(prompt, **llm_kwargs)
def start_interactive(self): def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)") print("\nLeann Chat started (type 'quit' to exit)")

View File

@@ -2,11 +2,14 @@ from abc import ABC, abstractmethod
 import numpy as np
 from typing import Dict, Any, List, Literal

 class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""

     @abstractmethod
-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
+    def build(
+        self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
+    ) -> None:
         """Build index

         Args:
@@ -17,6 +20,7 @@ class LeannBackendBuilderInterface(ABC):
""" """
pass pass
class LeannBackendSearcherInterface(ABC): class LeannBackendSearcherInterface(ABC):
"""Backend interface for searching""" """Backend interface for searching"""
@@ -31,14 +35,18 @@ class LeannBackendSearcherInterface(ABC):
         pass

     @abstractmethod
-    def search(self, query: np.ndarray, top_k: int,
-               complexity: int = 64,
-               beam_width: int = 1,
-               prune_ratio: float = 0.0,
-               recompute_embeddings: bool = False,
-               pruning_strategy: Literal["global", "local", "proportional"] = "global",
-               zmq_port: int = 5557,
-               **kwargs) -> Dict[str, Any]:
+    def search(
+        self,
+        query: np.ndarray,
+        top_k: int,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """Search for nearest neighbors

         Args:
@@ -57,6 +65,23 @@ class LeannBackendSearcherInterface(ABC):
""" """
pass pass
@abstractmethod
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""Compute embedding for a query string
Args:
query: The query string to embed
zmq_port: ZMQ port for embedding server
use_server_if_available: Whether to try using embedding server first
Returns:
Query embedding as numpy array with shape (1, D)
"""
pass
class LeannBackendFactoryInterface(ABC): class LeannBackendFactoryInterface(ABC):
"""Backend factory interface""" """Backend factory interface"""

View File

@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {port}")

+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """
+        Compute embedding for a query string.
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array
+        """
+        # Try to use embedding server if available and requested
+        if (
+            use_server_if_available
+            and self.embedding_server_manager
+            and self.embedding_server_manager.server_process
+        ):
+            try:
+                return self._compute_embedding_via_server([query], zmq_port)[
+                    0:1
+                ]  # Return (1, D) shape
+            except Exception as e:
+                print(f"⚠️ Embedding server failed: {e}")
+                print("⏭️ Falling back to direct model loading...")
+
+        # Fallback to direct computation
+        from .api import compute_embeddings
+
+        use_mlx = self.meta.get("use_mlx", False)
+        return compute_embeddings([query], self.embedding_model, use_mlx)
+
+    def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
+        """Compute embeddings using the ZMQ embedding server."""
+        import zmq
+        import msgpack
+
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.REQ)
+            socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout
+            socket.connect(f"tcp://localhost:{zmq_port}")
+
+            # Send embedding request
+            request = chunks
+            request_bytes = msgpack.packb(request)
+            socket.send(request_bytes)
+
+            # Wait for response
+            response_bytes = socket.recv()
+            response = msgpack.unpackb(response_bytes)
+
+            socket.close()
+            context.term()
+
+            # Convert response to numpy array
+            if isinstance(response, list) and len(response) > 0:
+                return np.array(response, dtype=np.float32)
+            else:
+                raise RuntimeError("Invalid response from embedding server")
+        except Exception as e:
+            raise RuntimeError(f"Failed to compute embeddings via server: {e}")
+
     @abstractmethod
     def search(
         self,
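The client helper above assumes a msgpack-over-ZMQ embedding server on localhost. Below is a minimal sketch of the server side of that protocol (REP socket, list of strings in, list of float vectors out); the actual embedding server in this repository may differ, and the model name and default port are assumptions.

# Minimal sketch of the server side this client protocol assumes: a REP socket
# receives a msgpack-encoded list of strings and replies with a msgpack-encoded
# list of float vectors. Model name and port are assumptions, not from this diff.
import msgpack
import zmq
from sentence_transformers import SentenceTransformer

def serve(port: int = 5557, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> None:
    model = SentenceTransformer(model_name)
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind(f"tcp://*:{port}")
    while True:
        chunks = msgpack.unpackb(sock.recv())       # list[str] from the client
        vectors = model.encode(chunks)              # (N, D) numpy array
        sock.send(msgpack.packb(vectors.tolist()))  # list[list[float]] back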