perf: reuse embedding server for query embed
@@ -7,7 +7,7 @@ import json
 import pickle
 import numpy as np
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
 import uuid
 import torch
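Note on the new import: `Literal` lets a static type checker reject strategy names outside the allowed set, while at runtime the value stays a plain str. A minimal illustration of how the `pruning_strategy` annotation below behaves (function name hypothetical):

    from typing import Literal

    PruningStrategy = Literal["global", "local", "proportional"]

    def check(strategy: PruningStrategy = "global") -> str:
        # mypy/pyright flag check("globle") as a type error; no runtime cost
        return strategy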
@@ -250,22 +250,41 @@ class LeannSearcher:
         final_kwargs["enable_warmup"] = enable_warmup
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
 
-    def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
+    def search(
+        self,
+        query: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
         print(f"  Query: '{query}'")
         print(f"  Top_k: {top_k}")
-        print(f"  Search kwargs: {search_kwargs}")
+        print(f"  Additional kwargs: {kwargs}")
 
-        query_embedding = compute_embeddings(
-            [query], self.embedding_model, self.use_mlx
-        )
+        # Use backend's compute_query_embedding method
+        # This will automatically use embedding server if available and needed
+        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
         print(f"  Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
 
-        # Add use_mlx to search kwargs
-        search_kwargs["use_mlx"] = self.use_mlx
-        results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
+        results = self.backend_impl.search(
+            query_embedding,
+            top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **kwargs,
+        )
         print(
             f"  Backend returned: labels={len(results.get('labels', [[]])[0])} results"
         )
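The `compute_query_embedding` implementation lives in the backend and is not part of this diff. As a rough sketch of the pattern the comment describes (try the embedding server on `zmq_port` first, fall back to computing locally), it might look like the following; the request/reply shape and the `embed_locally` helper are assumptions for illustration, not the actual backend protocol:

    import numpy as np
    import zmq

    def embed_locally(query: str) -> np.ndarray:
        # Placeholder: load/invoke the in-process embedding model here.
        raise NotImplementedError

    def compute_query_embedding_sketch(query: str, zmq_port: int = 5557) -> np.ndarray:
        # Hypothetical illustration: ask an already-running embedding server over REQ/REP.
        ctx = zmq.Context.instance()
        sock = ctx.socket(zmq.REQ)
        sock.setsockopt(zmq.RCVTIMEO, 2000)  # fail fast if no server is listening
        sock.connect(f"tcp://127.0.0.1:{zmq_port}")
        try:
            sock.send_json({"texts": [query]})   # assumed request shape
            reply = sock.recv_json()             # assumed reply shape
            return np.asarray(reply["embeddings"], dtype=np.float32)
        except zmq.Again:
            # No reply within the timeout: fall back to the local embedding path.
            return embed_locally(query)
        finally:
            sock.close()

Reusing the running server avoids loading the embedding model a second time just to embed one query, which is the perf win the commit title refers to.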
@@ -309,8 +328,33 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)
 
-    def ask(self, question: str, top_k=5, **kwargs):
-        results = self.searcher.search(question, top_k=top_k, **kwargs)
+    def ask(
+        self,
+        question: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **search_kwargs,
+    ):
+        if llm_kwargs is None:
+            llm_kwargs = {}
+
+        results = self.searcher.search(
+            question,
+            top_k=top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **search_kwargs,
+        )
         context = "\n\n".join([r.text for r in results])
         prompt = (
             "Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
             f"Question: {question}\n\n"
             "Please provide the best answer you can based on this context and your knowledge."
         )
-        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+        return self.llm.ask(prompt, **llm_kwargs)
 
     def start_interactive(self):
         print("\nLeann Chat started (type 'quit' to exit)")
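End-to-end, callers can now tune retrieval explicitly instead of smuggling options through opaque kwargs. A hypothetical usage sketch against the new signatures (import path, index path, and llm_config contents are illustrative, not taken from this diff):

    from leann import LeannChat  # import path assumed

    chat = LeannChat("indexes/docs.leann", llm_config={"type": "ollama", "model": "llama3"})
    answer = chat.ask(
        "Why is the query embedded via the backend now?",
        top_k=5,
        complexity=64,
        beam_width=1,
        pruning_strategy="global",        # checked against the Literal options
        zmq_port=5557,                    # port of the running embedding server
        llm_kwargs={"temperature": 0.2},  # forwarded verbatim to self.llm.ask
    )
    print(answer)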