Based on an excellent analysis from a user, implemented comprehensive fixes:

1. ZMQ Socket Cleanup:
   - Set LINGER=0 on all ZMQ sockets (client and server)
   - Use try-finally blocks to ensure socket.close() and context.term()
   - Prevents blocking on exit when ZMQ contexts have pending operations
2. Global Test Cleanup:
   - Added tests/conftest.py with a session-scoped cleanup fixture
   - Cleans up leftover ZMQ contexts and child processes after all tests
   - Lists remaining threads for debugging
3. CI Improvements:
   - Apply the timeout to ALL Python versions on Linux (not just 3.13)
   - Increased the timeout to 180s for better reliability
   - Added process cleanup (pkill) on timeout
4. Dependencies:
   - Added psutil>=5.9.0 to test dependencies for process management

Root cause: Python 3.9/3.13 are more sensitive to cleanup timing during interpreter shutdown. ZMQ's default LINGER=-1 was blocking exit, and atexit handlers were unreliable for cleanup. This should resolve the "all tests pass but CI hangs" issue.
"""
|
|
This file contains the core API for the LEANN project, now definitively updated
|
|
with the correct, original embedding logic from the user's reference code.
|
|
"""

import json
import logging
import pickle
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional

import numpy as np

from leann.interface import LeannBackendSearcherInterface

from .chat import get_llm
from .interface import LeannBackendFactoryInterface
from .registry import BACKEND_REGISTRY

logger = logging.getLogger(__name__)


def get_registered_backends() -> list[str]:
    """Get list of registered backend names."""
    return list(BACKEND_REGISTRY.keys())


def compute_embeddings(
    chunks: list[str],
    model_name: str,
    mode: str = "sentence-transformers",
    use_server: bool = True,
    port: Optional[int] = None,
    is_build: bool = False,
) -> np.ndarray:
    """
    Computes embeddings using different backends.

    Args:
        chunks: List of text chunks to embed
        model_name: Name of the embedding model
        mode: Embedding backend mode. Options:
            - "sentence-transformers": Use sentence-transformers library (default)
            - "mlx": Use MLX backend for Apple Silicon
            - "openai": Use OpenAI embedding API
        use_server: Whether to use the embedding server (True for search, False for build)
        port: Port of the embedding server (required when use_server is True)
        is_build: Whether the embeddings are being computed for an index build

    Returns:
        numpy array of embeddings
    """
    if use_server:
        # Use embedding server (for search/query)
        if port is None:
            raise ValueError("port is required when use_server is True")
        return compute_embeddings_via_server(chunks, model_name, port=port)
    else:
        # Use direct computation (for build_index)
        from .embedding_compute import (
            compute_embeddings as compute_embeddings_direct,
        )

        return compute_embeddings_direct(
            chunks,
            model_name,
            mode=mode,
            is_build=is_build,
        )
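

# Example (sketch, not executed): the two calling modes of compute_embeddings().
# The model name and port below are illustrative.
#
#   # Build time: compute embeddings in-process, no server required.
#   vecs = compute_embeddings(["hello world"], "facebook/contriever", use_server=False)
#
#   # Search time: a server must already be listening on the given ZMQ port.
#   vecs = compute_embeddings(["hello world"], "facebook/contriever", port=5557)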


def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
    """Computes embeddings by sending the chunks to a running embedding server over ZMQ.

    Args:
        chunks: List of text chunks to embed
        model_name: Name of the embedding model (only used for logging here;
            the server decides which model it serves)
        port: TCP port the embedding server is listening on
    """
    logger.info(
        f"Computing embeddings for {len(chunks)} chunks using model '{model_name}' (via embedding server)..."
    )
    import msgpack
    import numpy as np
    import zmq

    # Connect to embedding server
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
    socket.connect(f"tcp://localhost:{port}")

    try:
        # Send chunks to server for embedding computation
        request = chunks
        socket.send(msgpack.packb(request))

        # Receive embeddings from server
        response = socket.recv()
        embeddings_list = msgpack.unpackb(response)

        # Convert back to numpy array
        embeddings = np.array(embeddings_list, dtype=np.float32)
    finally:
        socket.close(linger=0)
        context.term()

    return embeddings
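

# Minimal sketch of the server side of the wire protocol the client above
# assumes (the real embedding server lives elsewhere in the codebase; names
# here are illustrative): a ZMQ REP socket receives a msgpack-encoded list of
# strings and replies with a msgpack-encoded list of embedding vectors.
#
#   import msgpack
#   import zmq
#
#   def serve(port: int, embed_fn) -> None:
#       context = zmq.Context()
#       socket = context.socket(zmq.REP)
#       socket.setsockopt(zmq.LINGER, 0)  # mirror the client: don't block on close
#       socket.bind(f"tcp://*:{port}")
#       try:
#           while True:
#               chunks = msgpack.unpackb(socket.recv())
#               embeddings = embed_fn(chunks)  # numpy array, one row per chunk
#               socket.send(msgpack.packb(embeddings.tolist()))
#       finally:
#           socket.close(linger=0)
#           context.term()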


@dataclass
class SearchResult:
    id: str
    score: float
    text: str
    metadata: dict[str, Any] = field(default_factory=dict)


class PassageManager:
    def __init__(
        self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
    ):
        self.offset_maps = {}
        self.passage_files = {}
        self.global_offset_map = {}  # Combined map for fast lookup

        for source in passage_sources:
            assert source["type"] == "jsonl", "only jsonl is supported"
            passage_file = source["path"]
            index_file = source["index_path"]  # .idx file

            # Fix path resolution - relative paths should be relative to metadata file directory
            if not Path(index_file).is_absolute():
                if metadata_file_path:
                    # Resolve relative to metadata file directory
                    metadata_dir = Path(metadata_file_path).parent
                    logger.debug(
                        f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
                    )
                    index_file = str((metadata_dir / index_file).resolve())
                    passage_file = str((metadata_dir / passage_file).resolve())
                    logger.debug(f"PassageManager: Resolved index_file: {index_file}")
                else:
                    # Fallback to current directory resolution (legacy behavior)
                    logger.warning(
                        "PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
                    )
                    logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
                    index_file = str(Path(index_file).resolve())
                    passage_file = str(Path(passage_file).resolve())
                    logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")

            if not Path(index_file).exists():
                raise FileNotFoundError(f"Passage index file not found: {index_file}")

            with open(index_file, "rb") as f:
                offset_map = pickle.load(f)
            self.offset_maps[passage_file] = offset_map
            self.passage_files[passage_file] = passage_file

            # Build global map for O(1) lookup
            for passage_id, offset in offset_map.items():
                self.global_offset_map[passage_id] = (passage_file, offset)

    def get_passage(self, passage_id: str) -> dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
            # Lazy file opening - only open when needed
            with open(passage_file, encoding="utf-8") as f:
                f.seek(offset)
                return json.loads(f.readline())
        raise KeyError(f"Passage ID not found: {passage_id}")
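

# Sketch of the on-disk layout PassageManager expects (file names illustrative):
# a JSONL file with one passage per line plus a pickled dict mapping passage id
# to the byte offset of its line, so get_passage() can seek directly to it.
#
#   import json, pickle
#
#   passages = [{"id": "0", "text": "hello", "metadata": {}}]
#   offsets = {}
#   with open("demo.passages.jsonl", "w", encoding="utf-8") as f:
#       for p in passages:
#           offsets[p["id"]] = f.tell()
#           f.write(json.dumps(p, ensure_ascii=False) + "\n")
#   with open("demo.passages.idx", "wb") as f:
#       pickle.dump(offsets, f)
#
#   pm = PassageManager(
#       [{"type": "jsonl", "path": "demo.passages.jsonl", "index_path": "demo.passages.idx"}]
#   )
#   assert pm.get_passage("0")["text"] == "hello"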


class LeannBuilder:
    def __init__(
        self,
        backend_name: str,
        embedding_model: str = "facebook/contriever",
        dimensions: Optional[int] = None,
        embedding_mode: str = "sentence-transformers",
        **backend_kwargs,
    ):
        self.backend_name = backend_name
        backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found or not registered.")
        self.backend_factory = backend_factory
        self.embedding_model = embedding_model
        self.dimensions = dimensions
        self.embedding_mode = embedding_mode

        # Check if we need to use cosine distance for normalized embeddings
        normalized_embeddings_models = {
            # OpenAI models
            ("openai", "text-embedding-ada-002"),
            ("openai", "text-embedding-3-small"),
            ("openai", "text-embedding-3-large"),
            # Voyage AI models
            ("voyage", "voyage-2"),
            ("voyage", "voyage-3"),
            ("voyage", "voyage-large-2"),
            ("voyage", "voyage-multilingual-2"),
            ("voyage", "voyage-code-2"),
            # Cohere models
            ("cohere", "embed-english-v3.0"),
            ("cohere", "embed-multilingual-v3.0"),
            ("cohere", "embed-english-light-v3.0"),
            ("cohere", "embed-multilingual-light-v3.0"),
        }

        # Also check for patterns in model names
        is_normalized = False
        current_model_lower = embedding_model.lower()
        current_mode_lower = embedding_mode.lower()

        # Check exact matches (and substring matches on the mode/model pairs)
        for mode, model in normalized_embeddings_models:
            if (current_mode_lower == mode and current_model_lower == model) or (
                mode in current_mode_lower and model in current_model_lower
            ):
                is_normalized = True
                break

        # Check patterns
        if not is_normalized:
            # OpenAI patterns
            if "openai" in current_mode_lower or "openai" in current_model_lower:
                if any(
                    pattern in current_model_lower
                    for pattern in ["text-embedding", "ada", "3-small", "3-large"]
                ):
                    is_normalized = True
            # Voyage patterns
            elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
                is_normalized = True
            # Cohere patterns
            elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
                if "embed" in current_model_lower:
                    is_normalized = True

        # Handle distance metric
        if is_normalized and "distance_metric" not in backend_kwargs:
            backend_kwargs["distance_metric"] = "cosine"
            warnings.warn(
                f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
                f"Automatically setting distance_metric='cosine' for optimal performance. "
                f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
                UserWarning,
                stacklevel=2,
            )
        elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
            current_metric = backend_kwargs.get("distance_metric", "mips")
            warnings.warn(
                f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
                f"'{embedding_model}' may lead to suboptimal search results. "
                f"Consider using 'cosine' distance metric for better performance.",
                UserWarning,
                stacklevel=2,
            )

        self.backend_kwargs = backend_kwargs
        self.chunks: list[dict[str, Any]] = []
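
    # Example (sketch, assuming the "hnsw" backend is registered): the detection
    # above auto-selects cosine distance for known normalized-embedding models
    # unless the caller overrides distance_metric explicitly.
    #
    #   b1 = LeannBuilder("hnsw", embedding_model="text-embedding-3-small",
    #                     embedding_mode="openai")
    #   assert b1.backend_kwargs["distance_metric"] == "cosine"  # auto-set, warns once
    #
    #   b2 = LeannBuilder("hnsw", embedding_model="text-embedding-3-small",
    #                     embedding_mode="openai", distance_metric="mips")
    #   # override respected, but a suboptimal-metric warning is emitted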

    def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
        if metadata is None:
            metadata = {}
        passage_id = metadata.get("id", str(len(self.chunks)))
        chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
        self.chunks.append(chunk_data)

    def build_index(self, index_path: str):
        if not self.chunks:
            raise ValueError("No chunks added.")
        if self.dimensions is None:
            self.dimensions = len(
                compute_embeddings(
                    ["dummy"],
                    self.embedding_model,
                    self.embedding_mode,
                    use_server=False,
                )[0]
            )
        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
        index_dir.mkdir(parents=True, exist_ok=True)
        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"
        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            try:
                from tqdm import tqdm

                chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
            except ImportError:
                chunk_iterator = self.chunks

            for chunk in chunk_iterator:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)
        texts_to_embed = [c["text"] for c in self.chunks]
        embeddings = compute_embeddings(
            texts_to_embed,
            self.embedding_model,
            self.embedding_mode,
            use_server=False,
            is_build=True,
        )
        string_ids = [chunk["id"] for chunk in self.chunks]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "embedding_mode": self.embedding_mode,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": passages_file.name,  # Use relative path (just filename)
                    "index_path": offset_file.name,  # Use relative path (just filename)
                }
            ],
        }

        # Add storage status flags for HNSW backend
        if self.backend_name == "hnsw":
            is_compact = self.backend_kwargs.get("is_compact", True)
            is_recompute = self.backend_kwargs.get("is_recompute", True)
            meta_data["is_compact"] = is_compact
            meta_data["is_pruned"] = (
                is_compact and is_recompute
            )  # Pruned only if compact and recompute
        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
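
    # Example (sketch): building a small index end to end. The backend name and
    # index path are illustrative; any registered backend works.
    #
    #   builder = LeannBuilder("hnsw", embedding_model="facebook/contriever")
    #   builder.add_text("LEANN is a vector index.", metadata={"id": "doc-0"})
    #   builder.add_text("It supports several search backends.", metadata={"id": "doc-1"})
    #   builder.build_index("./indexes/demo")
    #   # writes demo.passages.jsonl, demo.passages.idx and demo.meta.json under ./indexes/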

    def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
        """
        Build an index from pre-computed embeddings stored in a pickle file.

        Args:
            index_path: Path where the index will be saved
            embeddings_file: Path to pickle file containing (ids, embeddings) tuple
        """
        # Load pre-computed embeddings
        with open(embeddings_file, "rb") as f:
            data = pickle.load(f)

        if not isinstance(data, tuple) or len(data) != 2:
            raise ValueError(
                f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
            )

        ids, embeddings = data

        if not isinstance(embeddings, np.ndarray):
            raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")

        if len(ids) != embeddings.shape[0]:
            raise ValueError(
                f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
            )

        # Validate/set dimensions
        embedding_dim = embeddings.shape[1]
        if self.dimensions is None:
            self.dimensions = embedding_dim
        elif self.dimensions != embedding_dim:
            raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")

        logger.info(
            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
        )

        # Ensure we have text data for each embedding
        if len(self.chunks) != len(ids):
            # If no text chunks provided, create placeholder text entries
            if not self.chunks:
                logger.info("No text chunks provided, creating placeholder entries...")
                for id_val in ids:
                    self.add_text(
                        f"Document {id_val}",
                        metadata={"id": str(id_val), "from_embeddings": True},
                    )
            else:
                raise ValueError(
                    f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
                )

        # Build file structure
        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
        index_dir.mkdir(parents=True, exist_ok=True)
        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"

        # Write passages and create offset map
        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset

        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        # Build the vector index using precomputed embeddings
        string_ids = [str(id_val) for id_val in ids]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(embeddings, string_ids, index_path)

        # Create metadata file
        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "embedding_mode": self.embedding_mode,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": passages_file.name,  # Use relative path (just filename)
                    "index_path": offset_file.name,  # Use relative path (just filename)
                }
            ],
            "built_from_precomputed_embeddings": True,
            "embeddings_source": str(embeddings_file),
        }

        # Add storage status flags for HNSW backend
        if self.backend_name == "hnsw":
            is_compact = self.backend_kwargs.get("is_compact", True)
            is_recompute = self.backend_kwargs.get("is_recompute", True)
            meta_data["is_compact"] = is_compact
            meta_data["is_pruned"] = is_compact and is_recompute

        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)

        logger.info(f"Index built successfully from precomputed embeddings: {index_path}")


class LeannSearcher:
    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
        # Fix path resolution for Colab and other environments
        if not Path(index_path).is_absolute():
            index_path = str(Path(index_path).resolve())

        self.meta_path_str = f"{index_path}.meta.json"
        if not Path(self.meta_path_str).exists():
            parent_dir = Path(index_path).parent
            print(
                f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
            )
            # Highlight the FileNotFoundError in red
            raise FileNotFoundError(
                f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
            )
        with open(self.meta_path_str, encoding="utf-8") as f:
            self.meta_data = json.load(f)
        backend_name = self.meta_data["backend_name"]
        self.embedding_model = self.meta_data["embedding_model"]
        # Support both old and new format
        self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
        self.passage_manager = PassageManager(
            self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
        )
        backend_factory = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found.")
        final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
        final_kwargs["enable_warmup"] = enable_warmup
        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
            index_path, **final_kwargs
        )

    def search(
        self,
        query: str,
        top_k: int = 5,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        expected_zmq_port: int = 5557,
        **kwargs,
    ) -> list[SearchResult]:
        logger.info("🔍 LeannSearcher.search() called:")
        logger.info(f" Query: '{query}'")
        logger.info(f" Top_k: {top_k}")
        logger.info(f" Additional kwargs: {kwargs}")

        # Smart top_k detection and adjustment
        total_docs = len(self.passage_manager.global_offset_map)
        original_top_k = top_k
        if top_k > total_docs:
            top_k = total_docs
            logger.warning(
                f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
            )
            logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")

        zmq_port = None

        start_time = time.time()
        if recompute_embeddings:
            zmq_port = self.backend_impl._ensure_server_running(
                self.meta_path_str,
                port=expected_zmq_port,
                **kwargs,
            )
            del expected_zmq_port
        zmq_time = time.time() - start_time
        logger.info(f" Launching server time: {zmq_time} seconds")

        start_time = time.time()

        query_embedding = self.backend_impl.compute_query_embedding(
            query,
            use_server_if_available=recompute_embeddings,
            zmq_port=zmq_port,
        )
        # logger.info(f" Generated embedding shape: {query_embedding.shape}")
        # embedding_time = time.time() - start_time
        # logger.info(f" Embedding time: {embedding_time} seconds")

        start_time = time.time()
        results = self.backend_impl.search(
            query_embedding,
            top_k,
            complexity=complexity,
            beam_width=beam_width,
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
            zmq_port=zmq_port,
            **kwargs,
        )
        # logger.info(f" Search time: {time.time() - start_time} seconds")
        logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")

        enriched_results = []
        if "labels" in results and "distances" in results:
            logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
            # Python 3.9 does not support zip(strict=...); lengths are expected to match
            for i, (string_id, dist) in enumerate(
                zip(results["labels"][0], results["distances"][0])
            ):
                try:
                    passage_data = self.passage_manager.get_passage(string_id)
                    enriched_results.append(
                        SearchResult(
                            id=string_id,
                            score=dist,
                            text=passage_data["text"],
                            metadata=passage_data.get("metadata", {}),
                        )
                    )

                    # Color codes for better logging
                    GREEN = "\033[92m"
                    BLUE = "\033[94m"
                    YELLOW = "\033[93m"
                    RESET = "\033[0m"

                    # Truncate text for display (first 100 chars)
                    display_text = passage_data["text"][:100]
                    logger.info(
                        f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
                    )
                except KeyError:
                    RED = "\033[91m"
                    RESET = "\033[0m"
                    logger.error(
                        f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
                    )

        # Define color codes outside the loop for the final message
        GREEN = "\033[92m"
        RESET = "\033[0m"
        logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
        return enriched_results

    def cleanup(self):
        """Explicitly cleanup embedding server resources.

        This method should be called after you're done using the searcher,
        especially in test environments or batch processing scenarios.
        """
        if hasattr(self.backend_impl, "embedding_server_manager"):
            self.backend_impl.embedding_server_manager.stop_server()
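

# Example (sketch): searching an existing index. The index path and query are
# illustrative; cleanup() stops the embedding server the searcher spawned.
#
#   searcher = LeannSearcher("./indexes/demo")
#   try:
#       for r in searcher.search("what is LEANN?", top_k=3):
#           print(r.id, r.score, r.text[:80])
#   finally:
#       searcher.cleanup()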


class LeannChat:
    def __init__(
        self,
        index_path: str,
        llm_config: Optional[dict[str, Any]] = None,
        enable_warmup: bool = False,
        **kwargs,
    ):
        self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
        self.llm = get_llm(llm_config)

    def ask(
        self,
        question: str,
        top_k: int = 5,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        llm_kwargs: Optional[dict[str, Any]] = None,
        expected_zmq_port: int = 5557,
        **search_kwargs,
    ):
        if llm_kwargs is None:
            llm_kwargs = {}
        search_time = time.time()
        results = self.searcher.search(
            question,
            top_k=top_k,
            complexity=complexity,
            beam_width=beam_width,
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
            expected_zmq_port=expected_zmq_port,
            **search_kwargs,
        )
        search_time = time.time() - search_time
        # logger.info(f" Search time: {search_time} seconds")
        context = "\n\n".join([r.text for r in results])
        prompt = (
            "Here is some retrieved context that might help answer your question:\n\n"
            f"{context}\n\n"
            f"Question: {question}\n\n"
            "Please provide the best answer you can based on this context and your knowledge."
        )

        ask_time = time.time()
        ans = self.llm.ask(prompt, **llm_kwargs)
        ask_time = time.time() - ask_time
        logger.info(f" Ask time: {ask_time} seconds")
        return ans

    def start_interactive(self):
        print("\nLeann Chat started (type 'quit' to exit)")
        while True:
            try:
                user_input = input("You: ").strip()
                if user_input.lower() in ["quit", "exit"]:
                    break
                if not user_input:
                    continue
                response = self.ask(user_input)
                print(f"Leann: {response}")
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break

    def cleanup(self):
        """Explicitly cleanup embedding server resources.

        This method should be called after you're done using the chat interface,
        especially in test environments or batch processing scenarios.
        """
        if hasattr(self.searcher, "cleanup"):
            self.searcher.cleanup()