perf: reuse embedding server for query embed
@@ -1,7 +1,3 @@
-import faulthandler
-
-faulthandler.enable()
-
 import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.node_parser import SentenceSplitter
@@ -62,7 +58,7 @@ async def main(args):
 
     print(f"\n[PHASE 2] Starting Leann chat session...")
 
-    llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}
+    llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
 
     chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
@@ -7,7 +7,7 @@ import json
 import pickle
 import numpy as np
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
 import uuid
 import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
         final_kwargs["enable_warmup"] = enable_warmup
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
 
-    def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
+    def search(
+        self,
+        query: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
         print(f"  Query: '{query}'")
         print(f"  Top_k: {top_k}")
-        print(f"  Search kwargs: {search_kwargs}")
+        print(f"  Additional kwargs: {kwargs}")
 
-        query_embedding = compute_embeddings(
-            [query], self.embedding_model, self.use_mlx
-        )
+        # Use backend's compute_query_embedding method
+        # This will automatically use embedding server if available and needed
+        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
         print(f"  Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
 
-        # Add use_mlx to search kwargs
-        search_kwargs["use_mlx"] = self.use_mlx
-        results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
+        results = self.backend_impl.search(
+            query_embedding,
+            top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **kwargs,
+        )
         print(
             f"  Backend returned: labels={len(results.get('labels', [[]])[0])} results"
         )
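
Note: a minimal usage sketch of the new search() signature (the import path and index path here are assumed for illustration, not taken from this commit). With recompute_embeddings=True, both the query embedding and any node re-embeddings now go through the same ZMQ embedding server on zmq_port, so the embedding model is loaded only once:

    from leann.api import LeannSearcher  # import path assumed

    searcher = LeannSearcher("indexes/demo.leann")  # hypothetical index path
    results = searcher.search(
        "how is the query embedded?",
        top_k=5,
        complexity=64,              # search queue size, default from the signature
        recompute_embeddings=True,  # re-embed graph nodes via the ZMQ server
        zmq_port=5557,              # the same server now also embeds the query
    )
    for r in results:
        print(r.text[:80])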
@@ -309,8 +328,33 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)
 
-    def ask(self, question: str, top_k=5, **kwargs):
-        results = self.searcher.search(question, top_k=top_k, **kwargs)
+    def ask(
+        self,
+        question: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **search_kwargs,
+    ):
+        if llm_kwargs is None:
+            llm_kwargs = {}
+
+        results = self.searcher.search(
+            question,
+            top_k=top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **search_kwargs,
+        )
         context = "\n\n".join([r.text for r in results])
         prompt = (
             "Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
             f"Question: {question}\n\n"
             "Please provide the best answer you can based on this context and your knowledge."
         )
-        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+        return self.llm.ask(prompt, **llm_kwargs)
 
     def start_interactive(self):
         print("\nLeann Chat started (type 'quit' to exit)")
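
Note: under the new ask() signature, LLM generation options travel in an explicit llm_kwargs dict instead of being pulled out of **kwargs. A sketch (the index path and the generation option are assumed for illustration):

    chat = LeannChat(
        index_path="indexes/demo.leann",  # hypothetical index path
        llm_config={"type": "hf", "model": "Qwen/Qwen3-4B"},
    )
    answer = chat.ask(
        "What does pruning_strategy control?",
        top_k=5,
        pruning_strategy="proportional",
        llm_kwargs={"max_new_tokens": 256},  # assumed HF-style generation kwarg
    )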
@@ -2,13 +2,16 @@ from abc import ABC, abstractmethod
 import numpy as np
 from typing import Dict, Any, List, Literal
 
 
 class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""
 
     @abstractmethod
-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
+    def build(
+        self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
+    ) -> None:
         """Build index
 
         Args:
             data: Vector data (N, D)
             ids: List of string IDs for each vector
@@ -17,30 +20,35 @@ class LeannBackendBuilderInterface(ABC):
         """
         pass
 
 
 class LeannBackendSearcherInterface(ABC):
     """Backend interface for searching"""
 
     @abstractmethod
     def __init__(self, index_path: str, **kwargs):
         """Initialize searcher
 
         Args:
             index_path: Path to index file
             **kwargs: Backend-specific loading parameters
         """
         pass
 
     @abstractmethod
-    def search(self, query: np.ndarray, top_k: int,
-               complexity: int = 64,
-               beam_width: int = 1,
-               prune_ratio: float = 0.0,
-               recompute_embeddings: bool = False,
-               pruning_strategy: Literal["global", "local", "proportional"] = "global",
-               zmq_port: int = 5557,
-               **kwargs) -> Dict[str, Any]:
+    def search(
+        self,
+        query: np.ndarray,
+        top_k: int,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """Search for nearest neighbors
 
         Args:
             query: Query vectors (B, D) where B is batch size, D is dimension
             top_k: Number of nearest neighbors to return
@@ -51,23 +59,40 @@ class LeannBackendSearcherInterface(ABC):
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
             zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters
 
         Returns:
             {"labels": [...], "distances": [...]}
         """
         pass
 
+    @abstractmethod
+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """Compute embedding for a query string
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array with shape (1, D)
+        """
+        pass
+
 
 class LeannBackendFactoryInterface(ABC):
     """Backend factory interface"""
 
     @staticmethod
     @abstractmethod
     def builder(**kwargs) -> LeannBackendBuilderInterface:
         """Create Builder instance"""
         pass
 
     @staticmethod
     @abstractmethod
     def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
         """Create Searcher instance"""
         pass
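
Note: with compute_query_embedding() added to the interface, a backend now supplies three pieces: construction from an index path, vector search returning {"labels": ..., "distances": ...}, and query embedding. A minimal in-memory sketch of that contract (illustrative names and brute-force flat search, not a real Leann backend):

    import numpy as np
    from typing import Any, Dict

    class ToyDenseSearcher(LeannBackendSearcherInterface):
        """Illustrative only: flat inner-product search, no graph, no pruning."""

        def __init__(self, index_path: str, model=None, **kwargs):
            self.vectors = np.load(index_path)  # assumed (N, D) float32 layout
            self.model = model                  # assumed to expose .encode(list[str])

        def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
            scores = query @ self.vectors.T                  # (B, N) inner products
            labels = np.argsort(-scores, axis=1)[:, :top_k]  # best top_k per row
            dists = np.take_along_axis(scores, labels, axis=1)
            return {"labels": labels.tolist(), "distances": dists.tolist()}

        def compute_query_embedding(
            self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
        ) -> np.ndarray:
            # A real backend would try the shared ZMQ server first, as BaseSearcher does
            return np.asarray(self.model.encode([query]), dtype=np.float32)  # (1, D)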
@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {port}")
 
+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """
+        Compute embedding for a query string.
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array
+        """
+        # Try to use embedding server if available and requested
+        if (
+            use_server_if_available
+            and self.embedding_server_manager
+            and self.embedding_server_manager.server_process
+        ):
+            try:
+                return self._compute_embedding_via_server([query], zmq_port)[
+                    0:1
+                ]  # Return (1, D) shape
+            except Exception as e:
+                print(f"⚠️ Embedding server failed: {e}")
+                print("⏭️ Falling back to direct model loading...")
+
+        # Fallback to direct computation
+        from .api import compute_embeddings
+
+        use_mlx = self.meta.get("use_mlx", False)
+        return compute_embeddings([query], self.embedding_model, use_mlx)
+
+    def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
+        """Compute embeddings using the ZMQ embedding server."""
+        import zmq
+        import msgpack
+
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.REQ)
+            socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout
+            socket.connect(f"tcp://localhost:{zmq_port}")
+
+            # Send embedding request
+            request = chunks
+            request_bytes = msgpack.packb(request)
+            socket.send(request_bytes)
+
+            # Wait for response
+            response_bytes = socket.recv()
+            response = msgpack.unpackb(response_bytes)
+
+            socket.close()
+            context.term()
+
+            # Convert response to numpy array
+            if isinstance(response, list) and len(response) > 0:
+                return np.array(response, dtype=np.float32)
+            else:
+                raise RuntimeError("Invalid response from embedding server")
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to compute embeddings via server: {e}")
+
     @abstractmethod
     def search(
         self,
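
Note: _compute_embedding_via_server() defines the wire format implicitly: a msgpack-encoded list of strings sent over a REQ socket, answered by a msgpack-encoded list of embedding rows. A matching REP-side loop, inferred from the client code above (the actual server lives elsewhere in the codebase; model.encode is an assumed API):

    import msgpack
    import numpy as np
    import zmq

    def serve_embeddings(model, port: int = 5557):
        """Toy REP loop speaking the same format as _compute_embedding_via_server."""
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{port}")
        while True:
            chunks = msgpack.unpackb(socket.recv())  # list[str] of texts to embed
            vectors = model.encode(chunks)           # assumed (B, D) ndarray result
            socket.send(msgpack.packb(np.asarray(vectors, dtype=np.float32).tolist()))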