Remove the strict=False parameter from the zip() call in api.py: it was introduced in Python 3.10 and causes "TypeError: zip() takes no keyword arguments" on Python 3.9. The strict parameter controls whether zip() raises an exception when the iterables have different lengths. Since we are not relying on that behavior and the code works correctly without it, removing the parameter keeps the same functionality while restoring Python 3.9 compatibility.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
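For reference, a minimal sketch of the behavioral difference (variable names follow the call site in api.py; the 3.10+ behavior is shown only for comparison):

    labels, dists = results["labels"][0], results["distances"][0]
    zip(labels, dists)                # Python 3.9+: pairs items, stops silently at the shorter list
    zip(labels, dists, strict=False)  # Python 3.10+ only; on 3.9 any strict= keyword raises
                                      # TypeError: zip() takes no keyword arguments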
api.py · 659 lines · 25 KiB · Python
"""
|
|
This file contains the core API for the LEANN project, now definitively updated
|
|
with the correct, original embedding logic from the user's reference code.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import pickle
|
|
import time
|
|
import warnings
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Literal, Optional
|
|
|
|
import numpy as np
|
|
|
|
from leann.interface import LeannBackendSearcherInterface
|
|
|
|
from .chat import get_llm
|
|
from .interface import LeannBackendFactoryInterface
|
|
from .registry import BACKEND_REGISTRY
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_registered_backends() -> list[str]:
|
|
"""Get list of registered backend names."""
|
|
return list(BACKEND_REGISTRY.keys())
|
|
|
|
|
|
def compute_embeddings(
|
|
chunks: list[str],
|
|
model_name: str,
|
|
mode: str = "sentence-transformers",
|
|
use_server: bool = True,
|
|
port: Optional[int] = None,
|
|
is_build=False,
|
|
) -> np.ndarray:
|
|
"""
|
|
Computes embeddings using different backends.
|
|
|
|
Args:
|
|
chunks: List of text chunks to embed
|
|
model_name: Name of the embedding model
|
|
mode: Embedding backend mode. Options:
|
|
- "sentence-transformers": Use sentence-transformers library (default)
|
|
- "mlx": Use MLX backend for Apple Silicon
|
|
- "openai": Use OpenAI embedding API
|
|
use_server: Whether to use embedding server (True for search, False for build)
|
|
|
|
Returns:
|
|
numpy array of embeddings
|
|
"""
|
|
if use_server:
|
|
# Use embedding server (for search/query)
|
|
if port is None:
|
|
raise ValueError("port is required when use_server is True")
|
|
return compute_embeddings_via_server(chunks, model_name, port=port)
|
|
else:
|
|
# Use direct computation (for build_index)
|
|
from .embedding_compute import (
|
|
compute_embeddings as compute_embeddings_direct,
|
|
)
|
|
|
|
return compute_embeddings_direct(
|
|
chunks,
|
|
model_name,
|
|
mode=mode,
|
|
is_build=is_build,
|
|
)
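
# Example (illustrative; the model name matches the LeannBuilder default and the
# port matches the expected_zmq_port default used elsewhere in this file):
#
#   # Direct computation, as done during build_index():
#   build_vecs = compute_embeddings(
#       ["chunk one", "chunk two"], "facebook/contriever",
#       mode="sentence-transformers", use_server=False, is_build=True,
#   )
#   # Via a running embedding server, as done at query time:
#   query_vec = compute_embeddings(["a query"], "facebook/contriever", use_server=True, port=5557)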


def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
    """Computes embeddings by sending the chunks to a running embedding server over ZeroMQ.

    Args:
        chunks: List of text chunks to embed
        model_name: Name of the embedding model (used here for logging)
        port: ZMQ port the embedding server is listening on
    """
    logger.info(
        f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
    )
    import msgpack
    import numpy as np
    import zmq

    # Connect to embedding server
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(f"tcp://localhost:{port}")

    # Send chunks to server for embedding computation
    request = chunks
    socket.send(msgpack.packb(request))

    # Receive embeddings from server
    response = socket.recv()
    embeddings_list = msgpack.unpackb(response)

    # Convert back to numpy array
    embeddings = np.array(embeddings_list, dtype=np.float32)

    socket.close()
    context.term()

    return embeddings
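
# Wire format implied by this client (a sketch, not a formal protocol spec): the
# request is a msgpack-encoded list of strings and the reply must decode to a
# list of float vectors, e.g.
#
#   request:  msgpack.packb(["chunk one", "chunk two"])
#   response: msgpack.packb([[0.01, -0.23, ...], [0.44, 0.05, ...]])
#
# Any REP socket bound on the given port that answers in this shape will work
# with compute_embeddings_via_server().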


@dataclass
class SearchResult:
    id: str
    score: float
    text: str
    metadata: dict[str, Any] = field(default_factory=dict)


class PassageManager:
    def __init__(self, passage_sources: list[dict[str, Any]]):
        self.offset_maps = {}
        self.passage_files = {}
        self.global_offset_map = {}  # Combined map for fast lookup

        for source in passage_sources:
            assert source["type"] == "jsonl", "only jsonl is supported"
            passage_file = source["path"]
            index_file = source["index_path"]  # .idx file

            # Fix path resolution for Colab and other environments
            if not Path(index_file).is_absolute():
                # If relative path, try to resolve it properly
                index_file = str(Path(index_file).resolve())

            if not Path(index_file).exists():
                raise FileNotFoundError(f"Passage index file not found: {index_file}")

            with open(index_file, "rb") as f:
                offset_map = pickle.load(f)
            self.offset_maps[passage_file] = offset_map
            self.passage_files[passage_file] = passage_file

            # Build global map for O(1) lookup
            for passage_id, offset in offset_map.items():
                self.global_offset_map[passage_id] = (passage_file, offset)

    def get_passage(self, passage_id: str) -> dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
            # Lazy file opening - only open when needed
            with open(passage_file, encoding="utf-8") as f:
                f.seek(offset)
                return json.loads(f.readline())
        raise KeyError(f"Passage ID not found: {passage_id}")
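
# On-disk layout assumed by PassageManager (inferred from LeannBuilder below; the
# file names are illustrative): a JSONL file with one passage per line, plus a
# pickled dict mapping passage id -> byte offset of that line, e.g.
#
#   demo.passages.jsonl   {"id": "0", "text": "...", "metadata": {...}}   (one JSON object per line)
#   demo.passages.idx     pickle of {"0": 0, "1": 97, ...}
#
# get_passage() seeks to the stored offset and json-decodes a single line.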


class LeannBuilder:
    def __init__(
        self,
        backend_name: str,
        embedding_model: str = "facebook/contriever",
        dimensions: Optional[int] = None,
        embedding_mode: str = "sentence-transformers",
        **backend_kwargs,
    ):
        self.backend_name = backend_name
        backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found or not registered.")
        self.backend_factory = backend_factory
        self.embedding_model = embedding_model
        self.dimensions = dimensions
        self.embedding_mode = embedding_mode

        # Check if we need to use cosine distance for normalized embeddings
        normalized_embeddings_models = {
            # OpenAI models
            ("openai", "text-embedding-ada-002"),
            ("openai", "text-embedding-3-small"),
            ("openai", "text-embedding-3-large"),
            # Voyage AI models
            ("voyage", "voyage-2"),
            ("voyage", "voyage-3"),
            ("voyage", "voyage-large-2"),
            ("voyage", "voyage-multilingual-2"),
            ("voyage", "voyage-code-2"),
            # Cohere models
            ("cohere", "embed-english-v3.0"),
            ("cohere", "embed-multilingual-v3.0"),
            ("cohere", "embed-english-light-v3.0"),
            ("cohere", "embed-multilingual-light-v3.0"),
        }

        # Also check for patterns in model names
        is_normalized = False
        current_model_lower = embedding_model.lower()
        current_mode_lower = embedding_mode.lower()

        # Check exact matches
        for mode, model in normalized_embeddings_models:
            if (current_mode_lower == mode and current_model_lower == model) or (
                mode in current_mode_lower and model in current_model_lower
            ):
                is_normalized = True
                break

        # Check patterns
        if not is_normalized:
            # OpenAI patterns
            if "openai" in current_mode_lower or "openai" in current_model_lower:
                if any(
                    pattern in current_model_lower
                    for pattern in ["text-embedding", "ada", "3-small", "3-large"]
                ):
                    is_normalized = True
            # Voyage patterns
            elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
                is_normalized = True
            # Cohere patterns
            elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
                if "embed" in current_model_lower:
                    is_normalized = True

        # Handle distance metric
        if is_normalized and "distance_metric" not in backend_kwargs:
            backend_kwargs["distance_metric"] = "cosine"
            warnings.warn(
                f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
                f"Automatically setting distance_metric='cosine' for optimal performance. "
                f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
                UserWarning,
                stacklevel=2,
            )
        elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
            current_metric = backend_kwargs.get("distance_metric", "mips")
            warnings.warn(
                f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
                f"'{embedding_model}' may lead to suboptimal search results. "
                f"Consider using 'cosine' distance metric for better performance.",
                UserWarning,
                stacklevel=2,
            )

        self.backend_kwargs = backend_kwargs
        self.chunks: list[dict[str, Any]] = []

    def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
        if metadata is None:
            metadata = {}
        passage_id = metadata.get("id", str(len(self.chunks)))
        chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
        self.chunks.append(chunk_data)

    def build_index(self, index_path: str):
        if not self.chunks:
            raise ValueError("No chunks added.")
        if self.dimensions is None:
            self.dimensions = len(
                compute_embeddings(
                    ["dummy"],
                    self.embedding_model,
                    self.embedding_mode,
                    use_server=False,
                )[0]
            )
        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
        index_dir.mkdir(parents=True, exist_ok=True)
        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"
        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            try:
                from tqdm import tqdm

                chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
            except ImportError:
                chunk_iterator = self.chunks

            for chunk in chunk_iterator:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)
        texts_to_embed = [c["text"] for c in self.chunks]
        embeddings = compute_embeddings(
            texts_to_embed,
            self.embedding_model,
            self.embedding_mode,
            use_server=False,
            is_build=True,
        )
        string_ids = [chunk["id"] for chunk in self.chunks]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "embedding_mode": self.embedding_mode,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": str(passages_file),
                    "index_path": str(offset_file),
                }
            ],
        }

        # Add storage status flags for HNSW backend
        if self.backend_name == "hnsw":
            is_compact = self.backend_kwargs.get("is_compact", True)
            is_recompute = self.backend_kwargs.get("is_recompute", True)
            meta_data["is_compact"] = is_compact
            meta_data["is_pruned"] = (
                is_compact and is_recompute
            )  # Pruned only if compact and recompute
        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
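
    # Typical build flow (a sketch; the backend name, text, and paths are
    # illustrative, not prescribed by this file):
    #
    #   builder = LeannBuilder(backend_name="hnsw", embedding_model="facebook/contriever")
    #   builder.add_text("LEANN builds compact vector indexes.", metadata={"id": "doc-0"})
    #   builder.build_index("./indexes/demo")
    #
    # build_index() writes demo.passages.jsonl, demo.passages.idx, and demo.meta.json
    # alongside whatever files the backend itself produces.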

    def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
        """
        Build an index from pre-computed embeddings stored in a pickle file.

        Args:
            index_path: Path where the index will be saved
            embeddings_file: Path to pickle file containing (ids, embeddings) tuple
        """
        # Load pre-computed embeddings
        with open(embeddings_file, "rb") as f:
            data = pickle.load(f)

        if not isinstance(data, tuple) or len(data) != 2:
            raise ValueError(
                f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
            )

        ids, embeddings = data

        if not isinstance(embeddings, np.ndarray):
            raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")

        if len(ids) != embeddings.shape[0]:
            raise ValueError(
                f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
            )

        # Validate/set dimensions
        embedding_dim = embeddings.shape[1]
        if self.dimensions is None:
            self.dimensions = embedding_dim
        elif self.dimensions != embedding_dim:
            raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")

        logger.info(
            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
        )

        # Ensure we have text data for each embedding
        if len(self.chunks) != len(ids):
            # If no text chunks provided, create placeholder text entries
            if not self.chunks:
                logger.info("No text chunks provided, creating placeholder entries...")
                for id_val in ids:
                    self.add_text(
                        f"Document {id_val}",
                        metadata={"id": str(id_val), "from_embeddings": True},
                    )
            else:
                raise ValueError(
                    f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
                )

        # Build file structure
        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
        index_dir.mkdir(parents=True, exist_ok=True)
        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"

        # Write passages and create offset map
        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset

        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        # Build the vector index using precomputed embeddings
        string_ids = [str(id_val) for id_val in ids]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(embeddings, string_ids, index_path)

        # Create metadata file
        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "embedding_mode": self.embedding_mode,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": str(passages_file),
                    "index_path": str(offset_file),
                }
            ],
            "built_from_precomputed_embeddings": True,
            "embeddings_source": str(embeddings_file),
        }

        # Add storage status flags for HNSW backend
        if self.backend_name == "hnsw":
            is_compact = self.backend_kwargs.get("is_compact", True)
            is_recompute = self.backend_kwargs.get("is_recompute", True)
            meta_data["is_compact"] = is_compact
            meta_data["is_pruned"] = is_compact and is_recompute

        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)

        logger.info(f"Index built successfully from precomputed embeddings: {index_path}")


class LeannSearcher:
    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
        # Fix path resolution for Colab and other environments
        if not Path(index_path).is_absolute():
            index_path = str(Path(index_path).resolve())

        self.meta_path_str = f"{index_path}.meta.json"
        if not Path(self.meta_path_str).exists():
            parent_dir = Path(index_path).parent
            print(
                f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
            )
            # highlight in red the filenotfound error
            raise FileNotFoundError(
                f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
            )
        with open(self.meta_path_str, encoding="utf-8") as f:
            self.meta_data = json.load(f)
        backend_name = self.meta_data["backend_name"]
        self.embedding_model = self.meta_data["embedding_model"]
        # Support both old and new format
        self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
        self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
        backend_factory = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found.")
        final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
        final_kwargs["enable_warmup"] = enable_warmup
        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
            index_path, **final_kwargs
        )

    def search(
        self,
        query: str,
        top_k: int = 5,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        expected_zmq_port: int = 5557,
        **kwargs,
    ) -> list[SearchResult]:
        logger.info("🔍 LeannSearcher.search() called:")
        logger.info(f" Query: '{query}'")
        logger.info(f" Top_k: {top_k}")
        logger.info(f" Additional kwargs: {kwargs}")

        # Smart top_k detection and adjustment
        total_docs = len(self.passage_manager.global_offset_map)
        original_top_k = top_k
        if top_k > total_docs:
            top_k = total_docs
            logger.warning(
                f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
            )
            logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")

        zmq_port = None

        start_time = time.time()
        if recompute_embeddings:
            zmq_port = self.backend_impl._ensure_server_running(
                self.meta_path_str,
                port=expected_zmq_port,
                **kwargs,
            )
        del expected_zmq_port
        zmq_time = time.time() - start_time
        logger.info(f" Launching server time: {zmq_time} seconds")

        start_time = time.time()

        query_embedding = self.backend_impl.compute_query_embedding(
            query,
            use_server_if_available=recompute_embeddings,
            zmq_port=zmq_port,
        )
        # logger.info(f" Generated embedding shape: {query_embedding.shape}")
        time.time() - start_time
        # logger.info(f" Embedding time: {embedding_time} seconds")

        start_time = time.time()
        results = self.backend_impl.search(
            query_embedding,
            top_k,
            complexity=complexity,
            beam_width=beam_width,
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
            zmq_port=zmq_port,
            **kwargs,
        )
        time.time() - start_time
        # logger.info(f" Search time: {search_time} seconds")
        logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")

        enriched_results = []
        if "labels" in results and "distances" in results:
            logger.info(f" Processing {len(results['labels'][0])} passage IDs:")

            # Color codes for better logging (defined once, before the loop, so they
            # are always available for the per-result and summary log lines below)
            GREEN = "\033[92m"
            BLUE = "\033[94m"
            YELLOW = "\033[93m"
            RED = "\033[91m"
            RESET = "\033[0m"

            for i, (string_id, dist) in enumerate(
                zip(results["labels"][0], results["distances"][0])
            ):
                try:
                    passage_data = self.passage_manager.get_passage(string_id)
                    enriched_results.append(
                        SearchResult(
                            id=string_id,
                            score=dist,
                            text=passage_data["text"],
                            metadata=passage_data.get("metadata", {}),
                        )
                    )

                    # Text shown in the log output
                    display_text = passage_data["text"]
                    logger.info(
                        f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
                    )
                except KeyError:
                    logger.error(
                        f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
                    )

            logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
        return enriched_results
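
    # Example usage (a sketch; the index path is illustrative and must point to an
    # index built by LeannBuilder):
    #
    #   searcher = LeannSearcher("./indexes/demo")
    #   hits = searcher.search("how are passages stored?", top_k=3)
    #   for hit in hits:
    #       print(hit.id, hit.score, hit.text[:80])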


class LeannChat:
    def __init__(
        self,
        index_path: str,
        llm_config: Optional[dict[str, Any]] = None,
        enable_warmup: bool = False,
        **kwargs,
    ):
        self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
        self.llm = get_llm(llm_config)

    def ask(
        self,
        question: str,
        top_k: int = 5,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        llm_kwargs: Optional[dict[str, Any]] = None,
        expected_zmq_port: int = 5557,
        **search_kwargs,
    ):
        if llm_kwargs is None:
            llm_kwargs = {}
        search_time = time.time()
        results = self.searcher.search(
            question,
            top_k=top_k,
            complexity=complexity,
            beam_width=beam_width,
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
            expected_zmq_port=expected_zmq_port,
            **search_kwargs,
        )
        search_time = time.time() - search_time
        # logger.info(f" Search time: {search_time} seconds")
        context = "\n\n".join([r.text for r in results])
        prompt = (
            "Here is some retrieved context that might help answer your question:\n\n"
            f"{context}\n\n"
            f"Question: {question}\n\n"
            "Please provide the best answer you can based on this context and your knowledge."
        )

        ask_time = time.time()
        ans = self.llm.ask(prompt, **llm_kwargs)
        ask_time = time.time() - ask_time
        logger.info(f" Ask time: {ask_time} seconds")
        return ans

    def start_interactive(self):
        print("\nLeann Chat started (type 'quit' to exit)")
        while True:
            try:
                user_input = input("You: ").strip()
                if user_input.lower() in ["quit", "exit"]:
                    break
                if not user_input:
                    continue
                response = self.ask(user_input)
                print(f"Leann: {response}")
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break
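
# Example usage (a sketch; the index path is illustrative and must already exist,
# and llm_config is forwarded to get_llm() in leann.chat, whose options are not
# shown in this file):
#
#   chat = LeannChat("./indexes/demo")
#   print(chat.ask("What does LEANN store on disk?", top_k=3))
#   # or run a REPL: chat.start_interactive()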