perf: reuse embedding server for query embed
@@ -1,7 +1,3 @@
-import faulthandler
-
-faulthandler.enable()
-
 import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.node_parser import SentenceSplitter
@@ -62,7 +58,7 @@ async def main(args):
 
     print(f"\n[PHASE 2] Starting Leann chat session...")
 
-    llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}
+    llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
 
     chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
 
@@ -7,7 +7,7 @@ import json
 import pickle
 import numpy as np
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
 import uuid
 import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
         final_kwargs["enable_warmup"] = enable_warmup
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
 
-    def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
+    def search(
+        self,
+        query: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
         print(f" Query: '{query}'")
         print(f" Top_k: {top_k}")
-        print(f" Search kwargs: {search_kwargs}")
+        print(f" Additional kwargs: {kwargs}")
 
-        query_embedding = compute_embeddings(
-            [query], self.embedding_model, self.use_mlx
-        )
+        # Use backend's compute_query_embedding method
+        # This will automatically use embedding server if available and needed
+        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
         print(f" Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
 
-        # Add use_mlx to search kwargs
-        search_kwargs["use_mlx"] = self.use_mlx
-        results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
+        results = self.backend_impl.search(
+            query_embedding,
+            top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **kwargs,
+        )
         print(
             f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
         )
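
For reference, a minimal usage sketch of the widened search() signature. This is not part of the commit: the index path, the query, and the import path are assumptions for illustration.

# Hypothetical usage of the new explicit search parameters (sketch only).
from leann.api import LeannSearcher  # assumed module path

searcher = LeannSearcher("./demo.leann")  # placeholder index path
results = searcher.search(
    "what changed in the embedding path?",
    top_k=5,
    complexity=64,               # search queue size, forwarded to the backend
    recompute_embeddings=False,  # use stored neighbor embeddings as-is
    zmq_port=5557,               # port of the already-running embedding server
)
for r in results:
    print(r.text)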
@@ -309,8 +328,33 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)
 
-    def ask(self, question: str, top_k=5, **kwargs):
-        results = self.searcher.search(question, top_k=top_k, **kwargs)
+    def ask(
+        self,
+        question: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **search_kwargs,
+    ):
+        if llm_kwargs is None:
+            llm_kwargs = {}
+
+        results = self.searcher.search(
+            question,
+            top_k=top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **search_kwargs,
+        )
         context = "\n\n".join([r.text for r in results])
         prompt = (
             "Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
             f"Question: {question}\n\n"
             "Please provide the best answer you can based on this context and your knowledge."
         )
-        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+        return self.llm.ask(prompt, **llm_kwargs)
 
     def start_interactive(self):
         print("\nLeann Chat started (type 'quit' to exit)")
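
Likewise, a hypothetical ask() call showing the new explicit llm_kwargs parameter, which keeps generation options from being fished out of **kwargs where they could collide with search options. The index path and the generation option are illustrative assumptions.

# Sketch only; values are placeholders.
chat = LeannChat(
    index_path="./demo.leann",
    llm_config={"type": "hf", "model": "Qwen/Qwen3-4B"},
)
answer = chat.ask(
    "Summarize the retrieved notes.",
    top_k=5,
    pruning_strategy="global",
    llm_kwargs={"max_new_tokens": 256},  # assumed HF generation option
)
print(answer)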
@@ -2,11 +2,14 @@ from abc import ABC, abstractmethod
 import numpy as np
 from typing import Dict, Any, List, Literal
 
+
 class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""
 
     @abstractmethod
-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
+    def build(
+        self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
+    ) -> None:
         """Build index
 
         Args:
@@ -17,6 +20,7 @@ class LeannBackendBuilderInterface(ABC):
         """
         pass
 
+
 class LeannBackendSearcherInterface(ABC):
     """Backend interface for searching"""
 
@@ -31,14 +35,18 @@ class LeannBackendSearcherInterface(ABC):
         pass
 
     @abstractmethod
-    def search(self, query: np.ndarray, top_k: int,
-               complexity: int = 64,
-               beam_width: int = 1,
-               prune_ratio: float = 0.0,
-               recompute_embeddings: bool = False,
-               pruning_strategy: Literal["global", "local", "proportional"] = "global",
-               zmq_port: int = 5557,
-               **kwargs) -> Dict[str, Any]:
+    def search(
+        self,
+        query: np.ndarray,
+        top_k: int,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """Search for nearest neighbors
 
         Args:
@@ -57,6 +65,23 @@ class LeannBackendSearcherInterface(ABC):
         """
         pass
 
+    @abstractmethod
+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """Compute embedding for a query string
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array with shape (1, D)
+        """
+        pass
+
+
 class LeannBackendFactoryInterface(ABC):
     """Backend factory interface"""
 
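
For illustration, a toy backend satisfying the new abstract method. The module path and the SentenceTransformer model here are guesses, not part of this commit; real backends inherit the shared implementation from BaseSearcher in the next file.

import numpy as np
from sentence_transformers import SentenceTransformer

from leann.interface import LeannBackendSearcherInterface  # assumed module path


class ToySearcher(LeannBackendSearcherInterface):
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self._model = SentenceTransformer(model_name)

    def compute_query_embedding(
        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
    ) -> np.ndarray:
        # This toy backend has no embedding server, so it always computes
        # directly and ignores the server-related arguments.
        emb = self._model.encode([query])
        return np.asarray(emb, dtype=np.float32)  # shape (1, D)

    def search(self, query: np.ndarray, top_k: int, **kwargs):
        raise NotImplementedError  # remaining abstract methods elided in this sketch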
@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {port}")
 
+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """
+        Compute embedding for a query string.
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array
+        """
+        # Try to use embedding server if available and requested
+        if (
+            use_server_if_available
+            and self.embedding_server_manager
+            and self.embedding_server_manager.server_process
+        ):
+            try:
+                return self._compute_embedding_via_server([query], zmq_port)[
+                    0:1
+                ]  # Return (1, D) shape
+            except Exception as e:
+                print(f"⚠️ Embedding server failed: {e}")
+                print("⏭️ Falling back to direct model loading...")
+
+        # Fallback to direct computation
+        from .api import compute_embeddings
+
+        use_mlx = self.meta.get("use_mlx", False)
+        return compute_embeddings([query], self.embedding_model, use_mlx)
+
+    def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
+        """Compute embeddings using the ZMQ embedding server."""
+        import zmq
+        import msgpack
+
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.REQ)
+            socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout
+            socket.connect(f"tcp://localhost:{zmq_port}")
+
+            # Send embedding request
+            request = chunks
+            request_bytes = msgpack.packb(request)
+            socket.send(request_bytes)
+
+            # Wait for response
+            response_bytes = socket.recv()
+            response = msgpack.unpackb(response_bytes)
+
+            socket.close()
+            context.term()
+
+            # Convert response to numpy array
+            if isinstance(response, list) and len(response) > 0:
+                return np.array(response, dtype=np.float32)
+            else:
+                raise RuntimeError("Invalid response from embedding server")
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to compute embeddings via server: {e}")
+
     @abstractmethod
     def search(
         self,
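
The client above implies a simple wire protocol: a msgpack-encoded list of strings goes out over a ZMQ REQ socket, and a msgpack-encoded list of embedding rows comes back. Below is a minimal server sketch of that protocol; the SentenceTransformer model is a stand-in assumption, since the repo's actual embedding server is not part of this diff.

import msgpack
import numpy as np
import zmq
from sentence_transformers import SentenceTransformer  # assumed embedding backend


def serve_embeddings(port: int = 5557) -> None:
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)  # REP pairs with the client's REQ socket
    sock.bind(f"tcp://*:{port}")
    while True:
        chunks = msgpack.unpackb(sock.recv())  # list[str] from the client
        embeddings = model.encode(chunks)      # (N, D) float array
        payload = np.asarray(embeddings, dtype=np.float32).tolist()
        sock.send(msgpack.packb(payload))      # rows; client rebuilds via np.array(...)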