perf: reuse embedding server for query embed

Author: Andy Lee
Date: 2025-07-16 16:12:15 -07:00
parent 2a1a152073
commit f77c4e38cb

4 changed files with 169 additions and 38 deletions
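The gist of the change: LeannSearcher.search() previously embedded the query in-process via compute_embeddings(), which meant holding the embedding model in the client even when an embedding server was already running for the index. It now delegates to the backend's compute_query_embedding(), which can reuse that server. As a rough sketch of the idea (the function body, wire format, and field names below are assumptions for illustration, not the actual LEANN protocol; only the zmq_port parameter appears in this diff):

# Hypothetical sketch of reusing a running embedding server over ZMQ.
# The request/response shape is an assumption, not LEANN's actual protocol.
import numpy as np
import zmq

def compute_query_embedding(query: str, zmq_port: int = 5557) -> np.ndarray:
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://localhost:{zmq_port}")
    try:
        # Send the query text; receive its embedding vector(s) back.
        sock.send_json({"texts": [query]})
        reply = sock.recv_json()
        return np.asarray(reply["embeddings"], dtype=np.float32)
    finally:
        sock.close()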


@@ -7,7 +7,7 @@ import json
 import pickle
 import numpy as np
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
 import uuid
 import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
         final_kwargs["enable_warmup"] = enable_warmup
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)

-    def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
+    def search(
+        self,
+        query: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
         print(f" Query: '{query}'")
         print(f" Top_k: {top_k}")
-        print(f" Search kwargs: {search_kwargs}")
-        query_embedding = compute_embeddings(
-            [query], self.embedding_model, self.use_mlx
-        )
+        print(f" Additional kwargs: {kwargs}")
+        # Use backend's compute_query_embedding method
+        # This will automatically use embedding server if available and needed
+        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)

         print(f" Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")

-        # Add use_mlx to search kwargs
-        search_kwargs["use_mlx"] = self.use_mlx
-        results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
+        results = self.backend_impl.search(
+            query_embedding,
+            top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **kwargs,
+        )
         print(
             f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
         )
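For reference, a call under the new search() signature might look like this; the index path is a placeholder, and the inline notes on what each knob does are informed guesses rather than documented behavior:

# Hypothetical usage of the new signature; "my.index" is a placeholder path.
searcher = LeannSearcher("my.index")
results = searcher.search(
    "reuse the embedding server for query embedding",
    top_k=5,
    complexity=64,              # search-queue size for the graph traversal (assumed)
    beam_width=1,
    prune_ratio=0.0,
    recompute_embeddings=False,
    pruning_strategy="global",  # one of "global", "local", "proportional"
    zmq_port=5557,              # port of the running embedding server
)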
@@ -309,8 +328,33 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)

-    def ask(self, question: str, top_k=5, **kwargs):
-        results = self.searcher.search(question, top_k=top_k, **kwargs)
+    def ask(
+        self,
+        question: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **search_kwargs,
+    ):
+        if llm_kwargs is None:
+            llm_kwargs = {}
+
+        results = self.searcher.search(
+            question,
+            top_k=top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **search_kwargs,
+        )
         context = "\n\n".join([r.text for r in results])
         prompt = (
             "Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
             f"Question: {question}\n\n"
             "Please provide the best answer you can based on this context and your knowledge."
         )
-        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+        return self.llm.ask(prompt, **llm_kwargs)

     def start_interactive(self):
         print("\nLeann Chat started (type 'quit' to exit)")