""" This file contains the core API for the LEANN project, now definitively updated with the correct, original embedding logic from the user's reference code. """ import json import pickle from leann.interface import LeannBackendSearcherInterface import numpy as np import time from pathlib import Path from typing import List, Dict, Any, Optional, Literal from dataclasses import dataclass, field from .registry import BACKEND_REGISTRY from .interface import LeannBackendFactoryInterface from .chat import get_llm import logging logger = logging.getLogger(__name__) def compute_embeddings( chunks: List[str], model_name: str, mode: str = "sentence-transformers", use_server: bool = True, port: Optional[int] = None, is_build=False, ) -> np.ndarray: """ Computes embeddings using different backends. Args: chunks: List of text chunks to embed model_name: Name of the embedding model mode: Embedding backend mode. Options: - "sentence-transformers": Use sentence-transformers library (default) - "mlx": Use MLX backend for Apple Silicon - "openai": Use OpenAI embedding API use_server: Whether to use embedding server (True for search, False for build) Returns: numpy array of embeddings """ if use_server: # Use embedding server (for search/query) if port is None: raise ValueError("port is required when use_server is True") return compute_embeddings_via_server(chunks, model_name, port=port) else: # Use direct computation (for build_index) from .embedding_compute import ( compute_embeddings as compute_embeddings_direct, ) return compute_embeddings_direct( chunks, model_name, mode=mode, is_build=is_build, ) def compute_embeddings_via_server( chunks: List[str], model_name: str, port: int ) -> np.ndarray: """Computes embeddings using sentence-transformers. Args: chunks: List of text chunks to embed model_name: Name of the sentence transformer model """ logger.info( f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..." 


def compute_embeddings_via_server(
    chunks: List[str], model_name: str, port: int
) -> np.ndarray:
    """Computes embeddings by sending chunks to a running embedding server over ZMQ.

    Args:
        chunks: List of text chunks to embed.
        model_name: Name of the embedding model (used here for logging; the
            server is already bound to a specific model).
        port: Local TCP port the embedding server is listening on.
    """
    logger.info(
        f"Computing embeddings for {len(chunks)} chunks using model "
        f"'{model_name}' (via embedding server)..."
    )
    import msgpack
    import zmq

    # Connect to the embedding server
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(f"tcp://localhost:{port}")

    # Send chunks to the server for embedding computation
    request = chunks
    socket.send(msgpack.packb(request))

    # Receive embeddings from the server
    response = socket.recv()
    embeddings_list = msgpack.unpackb(response)

    # Convert back to a numpy array
    embeddings = np.array(embeddings_list, dtype=np.float32)

    socket.close()
    context.term()

    return embeddings


@dataclass
class SearchResult:
    id: str
    score: float
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)


class PassageManager:
    def __init__(self, passage_sources: List[Dict[str, Any]]):
        self.offset_maps = {}
        self.passage_files = {}
        self.global_offset_map = {}  # Combined map for fast lookup

        for source in passage_sources:
            assert source["type"] == "jsonl", "only jsonl is supported"
            passage_file = source["path"]
            index_file = source["index_path"]  # .idx file

            # Fix path resolution for Colab and other environments
            if not Path(index_file).is_absolute():
                # If the path is relative, resolve it against the working directory
                index_file = str(Path(index_file).resolve())
            if not Path(index_file).exists():
                raise FileNotFoundError(f"Passage index file not found: {index_file}")

            with open(index_file, "rb") as f:
                offset_map = pickle.load(f)

            self.offset_maps[passage_file] = offset_map
            self.passage_files[passage_file] = passage_file
            # Build a global map for O(1) lookup
            for passage_id, offset in offset_map.items():
                self.global_offset_map[passage_id] = (passage_file, offset)

    def get_passage(self, passage_id: str) -> Dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
            # Lazy file opening - only open when needed
            with open(passage_file, "r", encoding="utf-8") as f:
                f.seek(offset)
                return json.loads(f.readline())
        raise KeyError(f"Passage ID not found: {passage_id}")
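
# Example (illustrative sketch): the on-disk layout PassageManager expects.
# Each source pairs a JSONL passage file with a pickled offset map (".idx")
# from passage ID to the byte offset of that passage's line, so a lookup is a
# seek plus one readline instead of a file scan. The paths below are hypothetical.
#
#   manager = PassageManager(
#       [
#           {
#               "type": "jsonl",
#               "path": "my_index.passages.jsonl",
#               "index_path": "my_index.passages.idx",
#           }
#       ]
#   )
#   passage = manager.get_passage("0")  # {"id": "0", "text": ..., "metadata": ...}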


class LeannBuilder:
    def __init__(
        self,
        backend_name: str,
        embedding_model: str = "facebook/contriever",
        dimensions: Optional[int] = None,
        embedding_mode: str = "sentence-transformers",
        **backend_kwargs,
    ):
        self.backend_name = backend_name
        backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
            backend_name
        )
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found or not registered.")
        self.backend_factory = backend_factory
        self.embedding_model = embedding_model
        self.dimensions = dimensions
        self.embedding_mode = embedding_mode
        self.backend_kwargs = backend_kwargs
        self.chunks: List[Dict[str, Any]] = []

    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
        if metadata is None:
            metadata = {}
        passage_id = metadata.get("id", str(len(self.chunks)))
        chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
        self.chunks.append(chunk_data)

    def build_index(self, index_path: str):
        if not self.chunks:
            raise ValueError("No chunks added.")
        if self.dimensions is None:
            # Probe the model with a dummy chunk to discover the embedding dimension
            self.dimensions = len(
                compute_embeddings(
                    ["dummy"],
                    self.embedding_model,
                    self.embedding_mode,
                    use_server=False,
                )[0]
            )

        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
        index_dir.mkdir(parents=True, exist_ok=True)

        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"

        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            try:
                from tqdm import tqdm

                chunk_iterator = tqdm(
                    self.chunks, desc="Writing passages", unit="chunk"
                )
            except ImportError:
                chunk_iterator = self.chunks
            for chunk in chunk_iterator:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        texts_to_embed = [c["text"] for c in self.chunks]
        embeddings = compute_embeddings(
            texts_to_embed,
            self.embedding_model,
            self.embedding_mode,
            use_server=False,
            is_build=True,
        )
        string_ids = [chunk["id"] for chunk in self.chunks]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(
            embeddings, string_ids, index_path, **current_backend_kwargs
        )

        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "embedding_mode": self.embedding_mode,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": str(passages_file),
                    "index_path": str(offset_file),
                }
            ],
        }

        # Add storage status flags for the HNSW backend
        if self.backend_name == "hnsw":
            is_compact = self.backend_kwargs.get("is_compact", True)
            is_recompute = self.backend_kwargs.get("is_recompute", True)
            meta_data["is_compact"] = is_compact
            # Pruned only if both compact and recompute are enabled
            meta_data["is_pruned"] = is_compact and is_recompute

        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)
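
    # Example (illustrative sketch): preparing a pickle file that
    # build_index_from_embeddings below can consume. The file must hold a
    # 2-tuple of (ids, embeddings), where embeddings is a numpy array with one
    # row per ID; "my_embeddings.pkl" is a hypothetical path.
    #
    #   ids = ["doc-0", "doc-1"]
    #   vectors = np.random.rand(2, 768).astype(np.float32)
    #   with open("my_embeddings.pkl", "wb") as f:
    #       pickle.dump((ids, vectors), f)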

    def build_index_from_embeddings(self, index_path: str, embeddings_file: str):
        """Build an index from pre-computed embeddings stored in a pickle file.

        Args:
            index_path: Path where the index will be saved.
            embeddings_file: Path to a pickle file containing an (ids, embeddings) tuple.
        """
        # Load pre-computed embeddings
        with open(embeddings_file, "rb") as f:
            data = pickle.load(f)

        if not isinstance(data, tuple) or len(data) != 2:
            raise ValueError(
                f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}"
            )

        ids, embeddings = data
        if not isinstance(embeddings, np.ndarray):
            raise ValueError(
                f"Expected embeddings to be numpy array, got {type(embeddings)}"
            )
        if len(ids) != embeddings.shape[0]:
            raise ValueError(
                f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})"
            )

        # Validate/set dimensions
        embedding_dim = embeddings.shape[1]
        if self.dimensions is None:
            self.dimensions = embedding_dim
        elif self.dimensions != embedding_dim:
            raise ValueError(
                f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
            )

        logger.info(
            f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
        )

        # Ensure we have text data for each embedding
        if len(self.chunks) != len(ids):
            if not self.chunks:
                # No text chunks provided; create placeholder entries
                logger.info("No text chunks provided, creating placeholder entries...")
                for id_val in ids:
                    self.add_text(
                        f"Document {id_val}",
                        metadata={"id": str(id_val), "from_embeddings": True},
                    )
            else:
                raise ValueError(
                    f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})"
                )

        # Build file structure
        path = Path(index_path)
        index_dir = path.parent
        index_name = path.name
        index_dir.mkdir(parents=True, exist_ok=True)

        passages_file = index_dir / f"{index_name}.passages.jsonl"
        offset_file = index_dir / f"{index_name}.passages.idx"

        # Write passages and create the offset map
        offset_map = {}
        with open(passages_file, "w", encoding="utf-8") as f:
            for chunk in self.chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk["metadata"],
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        # Build the vector index using the precomputed embeddings
        string_ids = [str(id_val) for id_val in ids]
        current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
        builder_instance = self.backend_factory.builder(**current_backend_kwargs)
        builder_instance.build(embeddings, string_ids, index_path)

        # Create the metadata file
        leann_meta_path = index_dir / f"{index_name}.meta.json"
        meta_data = {
            "version": "1.0",
            "backend_name": self.backend_name,
            "embedding_model": self.embedding_model,
            "dimensions": self.dimensions,
            "backend_kwargs": self.backend_kwargs,
            "embedding_mode": self.embedding_mode,
            "passage_sources": [
                {
                    "type": "jsonl",
                    "path": str(passages_file),
                    "index_path": str(offset_file),
                }
            ],
            "built_from_precomputed_embeddings": True,
            "embeddings_source": str(embeddings_file),
        }

        # Add storage status flags for the HNSW backend
        if self.backend_name == "hnsw":
            is_compact = self.backend_kwargs.get("is_compact", True)
            is_recompute = self.backend_kwargs.get("is_recompute", True)
            meta_data["is_compact"] = is_compact
            meta_data["is_pruned"] = is_compact and is_recompute

        with open(leann_meta_path, "w", encoding="utf-8") as f:
            json.dump(meta_data, f, indent=2)

        logger.info(
            f"Index built successfully from precomputed embeddings: {index_path}"
        )
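
# Example (illustrative sketch): building an index end to end. The "hnsw"
# backend name and the is_compact/is_recompute flags come from the metadata
# handling above; the index path is hypothetical, and the chosen backend must
# be installed and registered in BACKEND_REGISTRY.
#
#   builder = LeannBuilder("hnsw", embedding_model="facebook/contriever")
#   builder.add_text("LEANN builds compact vector indexes.", metadata={"id": "doc-0"})
#   builder.add_text("Passages are stored as JSONL plus an offset map.")
#   builder.build_index("./indexes/demo")
#   # Produces demo.passages.jsonl, demo.passages.idx, and demo.meta.json
#   # alongside the backend-specific index files in ./indexes/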


class LeannSearcher:
    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
        # Fix path resolution for Colab and other environments
        if not Path(index_path).is_absolute():
            index_path = str(Path(index_path).resolve())

        self.meta_path_str = f"{index_path}.meta.json"
        if not Path(self.meta_path_str).exists():
            raise FileNotFoundError(
                f"Leann metadata file not found at {self.meta_path_str}"
            )
        with open(self.meta_path_str, "r", encoding="utf-8") as f:
            self.meta_data = json.load(f)

        backend_name = self.meta_data["backend_name"]
        self.embedding_model = self.meta_data["embedding_model"]
        # Support both the old and new metadata formats
        self.embedding_mode = self.meta_data.get(
            "embedding_mode", "sentence-transformers"
        )
        self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))

        backend_factory = BACKEND_REGISTRY.get(backend_name)
        if backend_factory is None:
            raise ValueError(f"Backend '{backend_name}' not found.")

        final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
        final_kwargs["enable_warmup"] = enable_warmup
        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
            index_path, **final_kwargs
        )

    def search(
        self,
        query: str,
        top_k: int = 5,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        expected_zmq_port: int = 5557,
        **kwargs,
    ) -> List[SearchResult]:
        logger.info("🔍 LeannSearcher.search() called:")
        logger.info(f"   Query: '{query}'")
        logger.info(f"   Top_k: {top_k}")
        logger.info(f"   Additional kwargs: {kwargs}")

        zmq_port = None
        start_time = time.time()
        if recompute_embeddings:
            zmq_port = self.backend_impl._ensure_server_running(
                self.meta_path_str,
                port=expected_zmq_port,
                **kwargs,
            )
        del expected_zmq_port
        zmq_time = time.time() - start_time
        logger.info(f"   Launching server time: {zmq_time} seconds")

        start_time = time.time()
        query_embedding = self.backend_impl.compute_query_embedding(
            query,
            use_server_if_available=recompute_embeddings,
            zmq_port=zmq_port,
        )
        # logger.info(f"   Generated embedding shape: {query_embedding.shape}")
        embedding_time = time.time() - start_time
        # logger.info(f"   Embedding time: {embedding_time} seconds")

        start_time = time.time()
        results = self.backend_impl.search(
            query_embedding,
            top_k,
            complexity=complexity,
            beam_width=beam_width,
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
            zmq_port=zmq_port,
            **kwargs,
        )
        search_time = time.time() - start_time
        # logger.info(f"   Search time: {search_time} seconds")
        logger.info(
            f"   Backend returned: labels={len(results.get('labels', [[]])[0])} results"
        )

        # ANSI color codes for readable logging (defined once, outside the loop,
        # since the final summary log below also uses them)
        GREEN = "\033[92m"
        BLUE = "\033[94m"
        YELLOW = "\033[93m"
        RED = "\033[91m"
        RESET = "\033[0m"

        enriched_results = []
        if "labels" in results and "distances" in results:
            logger.info(f"   Processing {len(results['labels'][0])} passage IDs:")
            for i, (string_id, dist) in enumerate(
                zip(results["labels"][0], results["distances"][0])
            ):
                try:
                    passage_data = self.passage_manager.get_passage(string_id)
                    enriched_results.append(
                        SearchResult(
                            id=string_id,
                            score=dist,
                            text=passage_data["text"],
                            metadata=passage_data.get("metadata", {}),
                        )
                    )
                    # Truncate text for display (first 100 chars)
                    display_text = passage_data["text"][:100]
                    logger.info(
                        f"   {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' "
                        f"{YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
                    )
                except KeyError:
                    logger.error(
                        f"   {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
                    )
        logger.info(
            f"   {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}"
        )
        return enriched_results
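
# Example (illustrative sketch): searching an existing index. Assumes
# "./indexes/demo" was built as in the LeannBuilder example above and that the
# backend can launch its embedding server (the default expected_zmq_port is 5557).
#
#   searcher = LeannSearcher("./indexes/demo")
#   hits = searcher.search("how are passages stored?", top_k=3)
#   for hit in hits:
#       print(hit.id, hit.score, hit.text[:80])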


class LeannChat:
    def __init__(
        self,
        index_path: str,
        llm_config: Optional[Dict[str, Any]] = None,
        enable_warmup: bool = False,
        **kwargs,
    ):
        self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
        self.llm = get_llm(llm_config)

    def ask(
        self,
        question: str,
        top_k: int = 5,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        llm_kwargs: Optional[Dict[str, Any]] = None,
        expected_zmq_port: int = 5557,
        **search_kwargs,
    ):
        if llm_kwargs is None:
            llm_kwargs = {}

        search_time = time.time()
        results = self.searcher.search(
            question,
            top_k=top_k,
            complexity=complexity,
            beam_width=beam_width,
            prune_ratio=prune_ratio,
            recompute_embeddings=recompute_embeddings,
            pruning_strategy=pruning_strategy,
            expected_zmq_port=expected_zmq_port,
            **search_kwargs,
        )
        search_time = time.time() - search_time
        # logger.info(f"   Search time: {search_time} seconds")

        context = "\n\n".join([r.text for r in results])
        prompt = (
            "Here is some retrieved context that might help answer your question:\n\n"
            f"{context}\n\n"
            f"Question: {question}\n\n"
            "Please provide the best answer you can based on this context and your knowledge."
        )
        ans = self.llm.ask(prompt, **llm_kwargs)
        return ans

    def start_interactive(self):
        print("\nLeann Chat started (type 'quit' to exit)")
        while True:
            try:
                user_input = input("You: ").strip()
                if user_input.lower() in ["quit", "exit"]:
                    break
                if not user_input:
                    continue
                response = self.ask(user_input)
                print(f"Leann: {response}")
            except (KeyboardInterrupt, EOFError):
                print("\nGoodbye!")
                break
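
# Example (illustrative sketch): retrieval-augmented chat over an index. The
# llm_config shown is hypothetical; the accepted keys depend on what get_llm()
# in .chat supports. The index path matches the build example above.
#
#   chat = LeannChat("./indexes/demo", llm_config={"type": "ollama", "model": "llama3"})
#   answer = chat.ask("How does LEANN store passages?", top_k=3)
#   print(answer)
#   # Or run an interactive loop:
#   chat.start_interactive()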