feat: support more embedders

Author: Andy Lee
Date: 2025-07-06 00:35:07 +00:00
Parent: 0aa84e147b
Commit: 910927a405

6 changed files with 142 additions and 85 deletions
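In short: the embedding model becomes a first-class, user-selectable parameter (SentenceTransformers or OpenAI), the embedding dimension is persisted in the index metadata instead of being re-derived from the model, and the HNSW embedding server learns the index's distance metric. A minimal usage sketch of the resulting API (names taken from the diffs below; the chunk text is illustrative):

    from leann import LeannBuilder

    # SentenceTransformers models work as before...
    builder = LeannBuilder(backend_name="hnsw",
                           embedding_model="sentence-transformers/all-mpnet-base-v2")
    builder.add_text("Leann is a low-storage vector index.")  # illustrative chunk
    builder.build_index("demo_knowledge.leann")

    # ...and OpenAI embedding models are now supported too
    # (requires OPENAI_API_KEY in the environment).
    openai_builder = LeannBuilder(backend_name="hnsw",
                                  embedding_model="text-embedding-ada-002")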

View File: examples/simple_demo.py

@@ -3,11 +3,17 @@ Simple demo showing basic leann usage
 Run: uv run python examples/simple_demo.py
 """
+import argparse
+
 from leann import LeannBuilder, LeannSearcher, LeannChat
 
 def main():
-    print("=== Leann Simple Demo ===")
+    parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
+    parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
+                        help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
+    args = parser.parse_args()
+    print(f"=== Leann Simple Demo with {args.embedding_model} ===")
     print()
 
     # Sample knowledge base
@@ -24,10 +30,11 @@ def main():
     print("1. Building index (no embeddings stored)...")
     builder = LeannBuilder(
-        embedding_model="sentence-transformers/all-mpnet-base-v2",
-        prune_ratio=0.7,  # Keep 30% of connections
+        embedding_model=args.embedding_model,
+        backend_name="hnsw",
     )
-    builder.add_chunks(chunks)
+    for chunk in chunks:
+        builder.add_text(chunk)
     builder.build_index("demo_knowledge.leann")
     print()
@@ -49,14 +56,7 @@ def main():
     print(f"   Text: {result.text[:100]}...")
     print()
 
-    print("3. Memory stats:")
-    stats = searcher.get_memory_stats()
-    print(f"   Cache size: {stats.embedding_cache_size}")
-    print(f"   Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
-    print(f"   Total chunks: {stats.total_chunks}")
-    print()
-
-    print("4. Interactive chat demo:")
+    print("3. Interactive chat demo:")
     print("   (Note: Requires OpenAI API key for real responses)")
     chat = LeannChat("demo_knowledge.leann")
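With the new flag, the same demo can exercise any supported embedder, for example: `uv run python examples/simple_demo.py --embedding_model text-embedding-ada-002`.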

View File: leann_backend_diskann (DiskANN backend)

@@ -143,20 +143,16 @@ class DiskannBackend(LeannBackendFactoryInterface):
         path = Path(index_path)
         meta_path = path.parent / f"{path.name}.meta.json"
         if not meta_path.exists():
             raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
 
         with open(meta_path, 'r') as f:
             meta = json.load(f)
 
-        try:
-            from sentence_transformers import SentenceTransformer
-            model = SentenceTransformer(meta.get("embedding_model"))
-            dimensions = model.get_sentence_embedding_dimension()
-            kwargs['dimensions'] = dimensions
-        except ImportError:
-            raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
-        except Exception as e:
-            raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
+        dimensions = meta.get("dimensions")
+        if not dimensions:
+            raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
+        kwargs['dimensions'] = dimensions
 
         return DiskannSearcher(index_path, **kwargs)
 
 class DiskannBuilder(LeannBackendBuilderInterface):
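The searcher now trusts the `dimensions` field persisted at build time rather than instantiating a SentenceTransformer, which is what makes OpenAI-only models workable here. For reference, the metadata it reads looks roughly like this (field names from `LeannBuilder` below; values illustrative):

    meta = {
        "version": "0.1.0",
        "backend_name": "hnsw",
        "embedding_model": "text-embedding-ada-002",
        "dimensions": 1536,    # persisted at build time
        "backend_kwargs": {},
        "num_chunks": 2,
        "chunks": [...],       # passage texts plus metadata
    }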

View File: leann_backend_hnsw (HNSW backend)

@@ -44,7 +44,7 @@ class HNSWEmbeddingServerManager:
         self.server_port = None
         atexit.register(self.stop_server)
 
-    def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None):
+    def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
         """
         Start the HNSW embedding server process.
@@ -52,6 +52,7 @@ class HNSWEmbeddingServerManager:
             port: ZMQ port for the server
             model_name: Name of the embedding model to use
             passages_file: Optional path to passages JSON file
+            distance_metric: The distance metric to use
         """
         if self.server_process and self.server_process.poll() is None:
             print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
@@ -69,7 +70,8 @@ class HNSWEmbeddingServerManager:
             sys.executable,
             "-m", "leann_backend_hnsw.hnsw_embedding_server",
             "--zmq-port", str(port),
-            "--model-name", model_name
+            "--model-name", model_name,
+            "--distance-metric", distance_metric
         ]
 
         if passages_file:
@@ -150,21 +152,16 @@ class HNSWBackend(LeannBackendFactoryInterface):
         path = Path(index_path)
         meta_path = path.parent / f"{path.name}.meta.json"
         if not meta_path.exists():
             raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
 
         with open(meta_path, 'r') as f:
             meta = json.load(f)
 
-        try:
-            from sentence_transformers import SentenceTransformer
-            model = SentenceTransformer(meta.get("embedding_model"))
-            dimensions = model.get_sentence_embedding_dimension()
-            kwargs['dimensions'] = dimensions
-        except ImportError:
-            raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
-        except Exception as e:
-            raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
+        dimensions = meta.get("dimensions")
+        if not dimensions:
+            raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
+        kwargs['dimensions'] = dimensions
 
         return HNSWSearcher(index_path, **kwargs)
 
 class HNSWBuilder(LeannBackendBuilderInterface):
@@ -172,10 +169,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         self.build_params = kwargs.copy()
 
         # --- Configuration defaults with standardized names ---
-        # Apply defaults and write them back to the build_params dict
-        # so they can be saved in the metadata file by LeannBuilder.
         self.is_compact = self.build_params.setdefault("is_compact", True)
-        self.is_recompute = self.build_params.setdefault("is_recompute", True)  # Default: prune embeddings
+        self.is_recompute = self.build_params.setdefault("is_recompute", True)
 
         # --- Additional Options ---
         self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
@@ -186,6 +181,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         self.M = self.build_params.setdefault("M", 32)
         self.efConstruction = self.build_params.setdefault("efConstruction", 200)
         self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
+        self.dimensions = self.build_params.get("dimensions")
 
         if self.is_skip_neighbors and not self.is_compact:
             raise ValueError("is_skip_neighbors can only be used with is_compact=True")
@@ -210,30 +206,25 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         metric_str = self.distance_metric.lower()
         metric_enum = get_metric_map().get(metric_str)
-        print('metric_enum', metric_enum, ' metric_str', metric_str)
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
 
         M = self.M
         efConstruction = self.efConstruction
-        dim = data.shape[1]
+        dim = self.dimensions
+        if not dim:
+            dim = data.shape[1]
 
         print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
         try:
-            if metric_enum == faiss.METRIC_INNER_PRODUCT:
-                index = faiss.IndexHNSWFlat(dim, M, metric_enum)
-            else:  # L2
-                index = faiss.IndexHNSWFlat(dim, M, metric_enum)
+            index = faiss.IndexHNSWFlat(dim, M, metric_enum)
             index.hnsw.efConstruction = efConstruction
 
             if metric_str == "cosine":
                 faiss.normalize_L2(data)
-            print('starting to add vectors to index')
             index.add(data.shape[0], faiss.swig_ptr(data))
-            print('vectors added to index')
 
             index_file = index_dir / f"{index_prefix}.index"
             faiss.write_index(index, str(index_file))
@@ -243,7 +234,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if self.is_compact:
             self._convert_to_csr(index_file)
 
-        # Generate passages file for recompute mode
        if self.is_recompute:
             self._generate_passages_file(index_dir, index_prefix, **kwargs)
@@ -423,13 +413,11 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         # Apply additional configuration options with strict validation
         hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
-        # If index is pruned, force recompute mode regardless of user setting
         hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
         hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
         hnsw_config.external_storage_path = self.config.get("external_storage_path")
         hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
 
-        # CRITICAL ASSERTION: If index is pruned, recompute MUST be enabled
         if self.is_pruned and not hnsw_config.is_recompute:
             raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -487,7 +475,7 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         if _check_port(zmq_port):
             print(f"INFO: Embedding server already running on port {zmq_port}")
         else:
-            if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
+            if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
                 raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
 
         # Give server extra time to fully initialize
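Since backend kwargs flow from `LeannBuilder` into `HNSWBuilder.build_params` and are saved in the metadata, the metric chosen at build time is what eventually reaches the server via `--distance-metric`. A sketch, assuming `distance_metric` is passed as a backend kwarg (per the `setdefault` above):

    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="sentence-transformers/all-mpnet-base-v2",
        distance_metric="cosine",  # "mips" (default), "l2", or "cosine"
    )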

View File: leann_backend_hnsw/hnsw_embedding_server.py

@@ -85,6 +85,7 @@ def create_hnsw_embedding_server(
     max_batch_size: int = 128,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     custom_max_length_param: Optional[int] = None,
+    distance_metric: str = "mips",
 ):
     """
     Create and start a ZMQ-based embedding server for HNSW backend.
@@ -100,6 +101,7 @@ def create_hnsw_embedding_server(
         max_batch_size: Maximum batch size for processing
         model_name: Transformer model name
         custom_max_length_param: Custom max sequence length
+        distance_metric: The distance metric to use
     """
     print(f"Loading tokenizer for {model_name}...")
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
@@ -222,6 +224,7 @@ def create_hnsw_embedding_server(
     def process_batch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts and return embeddings"""
         _is_e5_model = "e5" in model_name.lower()
+        _is_bge_model = "bge" in model_name.lower()
         batch_size = len(texts_batch)
 
         # E5 model preprocessing
@@ -262,7 +265,9 @@ def create_hnsw_embedding_server(
             out = model(enc["input_ids"], enc["attention_mask"])
 
         with pool_timer.timing():
-            if not hasattr(out, 'last_hidden_state'):
+            if _is_bge_model:
+                pooled_embeddings = out.last_hidden_state[:, 0]
+            elif not hasattr(out, 'last_hidden_state'):
                 if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                     pooled_embeddings = out
                 else:
@@ -279,7 +284,7 @@ def create_hnsw_embedding_server(
                 pooled_embeddings = sum_embeddings / sum_mask
 
         final_embeddings = pooled_embeddings
-        if _is_e5_model:
+        if _is_e5_model or _is_bge_model:
             with norm_timer.timing():
                 final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
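For context: BGE models are trained so that the [CLS] token (first position), L2-normalized, is the sentence embedding, while most other models here are mean-pooled over the attention mask. A self-contained sketch of the two strategies, assuming a Hugging Face-style `last_hidden_state` of shape (batch, seq, dim):

    import torch
    import torch.nn.functional as F

    def pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor,
             use_cls: bool) -> torch.Tensor:
        """Illustrative pooling: CLS token for BGE-style models, masked mean otherwise."""
        if use_cls:
            pooled = last_hidden_state[:, 0]  # first token is [CLS]
        else:
            mask = attention_mask.unsqueeze(-1).float()
            pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return F.normalize(pooled, p=2, dim=1)  # BGE/E5 embeddings are compared unit-length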
@@ -408,14 +413,14 @@ def create_hnsw_embedding_server(
         calc_timer = DeviceTimer("distance calculation", device)
         with calc_timer.timing():
             with torch.no_grad():
-                if is_similarity_metric():
-                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
-                    query_np = query_tensor.cpu().numpy()
-                    distances = -np.dot(node_embeddings_np, query_np)
-                else:
+                if distance_metric == "l2":
                     node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
                     query_np = query_tensor.cpu().numpy().astype(np.float32)
                     distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
+                else:  # mips or cosine
+                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
+                    query_np = query_tensor.cpu().numpy()
+                    distances = -np.dot(node_embeddings_np, query_np)
         calc_timer.print_elapsed()
 
         try:
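Note the sign convention the server keeps: for MIPS/cosine it returns negated dot products so that, as with squared L2, smaller values always mean "closer". A quick standalone check (random data; shapes assumed as in the server):

    import numpy as np

    nodes = np.random.rand(4, 8).astype(np.float32)  # (n, d) node embeddings
    query = np.random.rand(8).astype(np.float32)     # (d,) query embedding

    mips = -np.dot(nodes, query)                                  # mips / cosine branch
    l2 = np.sum(np.square(nodes - query.reshape(1, -1)), axis=1)  # l2 branch

    # Either way, the best candidate minimizes the distance.
    best_mips = int(np.argmin(mips))  # same index as np.argmax(np.dot(nodes, query))
    best_l2 = int(np.argmin(l2))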
@@ -572,6 +577,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                         help="Embedding model name")
     parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
+    parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
 
     args = parser.parse_args()
@@ -586,4 +592,5 @@ if __name__ == "__main__":
         max_batch_size=args.max_batch_size,
         model_name=args.model_name,
         custom_max_length_param=args.custom_max_length,
+        distance_metric=args.distance_metric,
     )

View File: leann/__init__.py (new file)

@@ -0,0 +1,17 @@
+# This file makes the 'leann' directory a Python package.
+
+from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult
+
+# Import backends to ensure they are registered
+try:
+    import leann_backend_hnsw
+except ImportError:
+    pass
+
+try:
+    import leann_backend_diskann
+except ImportError:
+    pass
+
+__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']
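So downstream code can now import everything from the package root, and a missing optional backend simply stays unregistered rather than breaking the import:

    from leann import LeannBuilder, LeannSearcher, LeannChat, SearchResult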

View File: leann/api.py

@@ -6,22 +6,69 @@ import os
 import json
 from pathlib import Path
 import openai
+from dataclasses import dataclass, field
+
+# --- Helper Functions for Embeddings ---
+
+def _get_openai_client():
+    """Initializes and returns an OpenAI client, ensuring the API key is set."""
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
+    return openai.OpenAI(api_key=api_key)
+
+def _is_openai_model(model_name: str) -> bool:
+    """Checks if the model is likely an OpenAI embedding model."""
+    # This is a simple check, can be improved with a more robust list.
+    return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
 
 def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
-    from sentence_transformers import SentenceTransformer
-    # TODO: use a better embedding model
-    model = SentenceTransformer(model_name)
-    print(f"INFO: Computing embeddings for {len(chunks)} chunks using '{model_name}'...")
-    embeddings = model.encode(chunks, show_progress_bar=True)
+    """Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
+    if _is_openai_model(model_name):
+        print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
+        client = _get_openai_client()
+        response = client.embeddings.create(model=model_name, input=chunks)
+        embeddings = [item.embedding for item in response.data]
+    else:
+        from sentence_transformers import SentenceTransformer
+        model = SentenceTransformer(model_name)
+        print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
+        embeddings = model.encode(chunks, show_progress_bar=True)
     return np.asarray(embeddings, dtype=np.float32)
 
+def _get_embedding_dimensions(model_name: str) -> int:
+    """Gets the embedding dimensions for a given model."""
+    print(f"INFO: Calculating dimensions for model '{model_name}'...")
+    if _is_openai_model(model_name):
+        client = _get_openai_client()
+        response = client.embeddings.create(model=model_name, input=["dummy text"])
+        return len(response.data[0].embedding)
+    else:
+        from sentence_transformers import SentenceTransformer
+        model = SentenceTransformer(model_name)
+        dimension = model.get_sentence_embedding_dimension()
+        if dimension is None:
+            raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
+        return dimension
+
+@dataclass
+class SearchResult:
+    """Represents a single search result."""
+    id: int
+    score: float
+    text: str
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+# --- Core Classes ---
+
 class LeannBuilder:
     """
     The builder is responsible for building the index, it will compute the embeddings and then build the index.
     It will also save the metadata of the index.
     """
-    def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", **backend_kwargs):
+    def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
         self.backend_name = backend_name
         backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
@@ -29,6 +76,7 @@ class LeannBuilder:
         self.backend_factory = backend_factory
         self.embedding_model = embedding_model
+        self.dimensions = dimensions
         self.backend_kwargs = backend_kwargs
         self.chunks: List[Dict[str, Any]] = []
         print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
@@ -40,12 +88,18 @@ class LeannBuilder:
         if not self.chunks:
             raise ValueError("No chunks added. Use add_text() first.")
 
+        if self.dimensions is None:
+            self.dimensions = _get_embedding_dimensions(self.embedding_model)
+            print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
+
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
 
-        builder_instance = self.backend_factory.builder(**self.backend_kwargs)
-        # Pass chunks data for passages file generation
-        build_kwargs = self.backend_kwargs.copy()
+        current_backend_kwargs = self.backend_kwargs.copy()
+        current_backend_kwargs['dimensions'] = self.dimensions
+        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
+
+        build_kwargs = current_backend_kwargs.copy()
         build_kwargs['chunks'] = self.chunks
 
         builder_instance.build(embeddings, index_path, **build_kwargs)
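Callers who already know their model's output size can pass `dimensions` up front and skip the auto-detection probe, which for OpenAI models costs one dummy embedding call. A sketch (1536 is the known output size of text-embedding-ada-002):

    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="text-embedding-ada-002",
        dimensions=1536,  # skips _get_embedding_dimensions()
    )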
@@ -56,6 +110,7 @@ class LeannBuilder:
             "version": "0.1.0",
             "backend_name": self.backend_name,
             "embedding_model": self.embedding_model,
+            "dimensions": self.dimensions,
             "backend_kwargs": self.backend_kwargs,
             "num_chunks": len(self.chunks),
             "chunks": self.chunks,
@@ -87,6 +142,8 @@ class LeannSearcher:
         final_kwargs = self.meta_data.get("backend_kwargs", {})
         final_kwargs.update(backend_kwargs)
+        if 'dimensions' not in final_kwargs:
+            final_kwargs['dimensions'] = self.meta_data.get('dimensions')
 
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
         print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
@@ -94,18 +151,19 @@ class LeannSearcher:
     def search(self, query: str, top_k: int = 5, **search_kwargs):
         query_embedding = _compute_embeddings([query], self.embedding_model)
+        search_kwargs['embedding_model'] = self.embedding_model
         results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
 
         enriched_results = []
         for label, dist in zip(results['labels'][0], results['distances'][0]):
             if label < len(self.meta_data['chunks']):
                 chunk_info = self.meta_data['chunks'][label]
-                enriched_results.append({
-                    "id": label,
-                    "score": dist,
-                    "text": chunk_info['text'],
-                    "metadata": chunk_info['metadata']
-                })
+                enriched_results.append(SearchResult(
+                    id=label,
+                    score=dist,
+                    text=chunk_info['text'],
+                    metadata=chunk_info.get('metadata', {})
+                ))
         return enriched_results
@@ -125,15 +183,6 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, **kwargs)
         self.llm_model = llm_model
-        self.openai_client = None  # Lazy load
-
-    def _get_openai_client(self):
-        if self.openai_client is None:
-            api_key = os.getenv("OPENAI_API_KEY")
-            if not api_key:
-                raise ValueError("OPENAI_API_KEY environment variable not set.")
-            self.openai_client = openai.OpenAI(api_key=api_key)
-        return self.openai_client
 
     def ask(self, question: str, top_k=5, **kwargs):
         """
@@ -165,13 +214,13 @@ class LeannChat:
         """
         results = self.searcher.search(question, top_k=top_k, **kwargs)
-        context = "\n\n".join([r['text'] for r in results])
+        context = "\n\n".join([r.text for r in results])
 
         prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
         print(f"DEBUG: Calling LLM with prompt: {prompt}...")
 
         try:
-            client = self._get_openai_client()
+            client = _get_openai_client()
             response = client.chat.completions.create(
                 model=self.llm_model,
                 messages=[
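With `search()` now returning `SearchResult` dataclasses, callers move from dict subscripting to attribute access; a small sketch against an index built as in the demo (query text illustrative):

    searcher = LeannSearcher("demo_knowledge.leann")
    results = searcher.search("What does Leann store on disk?", top_k=3)
    for r in results:
        # r.score / r.text replace r["score"] / r["text"]
        print(f"{r.score:.3f}  {r.text[:80]}")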