diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py index f80d7a6..6b07f09 100644 --- a/examples/run_evaluation.py +++ b/examples/run_evaluation.py @@ -13,10 +13,10 @@ import sys import numpy as np from typing import List -from leann.api import LeannSearcher +from leann.api import LeannSearcher, LeannBuilder -def download_data_if_needed(data_root: Path): +def download_data_if_needed(data_root: Path, download_embeddings: bool = False): """Checks if the data directory exists, and if not, downloads it from HF Hub.""" if not data_root.exists(): print(f"Data directory '{data_root}' not found.") @@ -26,13 +26,32 @@ def download_data_if_needed(data_root: Path): try: from huggingface_hub import snapshot_download - snapshot_download( - repo_id="LEANN-RAG/leann-rag-evaluation-data", - repo_type="dataset", - local_dir=data_root, - local_dir_use_symlinks=False, # Recommended for Windows compatibility and simpler structure - ) - print("Data download complete!") + if download_embeddings: + # Download everything including embeddings (large files) + snapshot_download( + repo_id="LEANN-RAG/leann-rag-evaluation-data", + repo_type="dataset", + local_dir=data_root, + local_dir_use_symlinks=False, + ) + print("Data download complete (including embeddings)!") + else: + # Download only specific folders, excluding embeddings + allow_patterns = [ + "ground_truth/**", + "indices/**", + "queries/**", + "*.md", + "*.txt", + ] + snapshot_download( + repo_id="LEANN-RAG/leann-rag-evaluation-data", + repo_type="dataset", + local_dir=data_root, + local_dir_use_symlinks=False, + allow_patterns=allow_patterns, + ) + print("Data download complete (excluding embeddings)!") except ImportError: print( "Error: huggingface_hub is not installed. Please install it to download the data:" @@ -44,6 +63,43 @@ def download_data_if_needed(data_root: Path): sys.exit(1) +def download_embeddings_if_needed(data_root: Path, dataset_type: str = None): + """Download embeddings files specifically.""" + embeddings_dir = data_root / "embeddings" + + if dataset_type: + # Check if specific dataset embeddings exist + target_file = embeddings_dir / dataset_type / "passages_00.pkl" + if target_file.exists(): + print(f"Embeddings for {dataset_type} already exist") + return str(target_file) + + print("Downloading embeddings from HuggingFace Hub...") + try: + from huggingface_hub import snapshot_download + + # Download only embeddings folder + snapshot_download( + repo_id="LEANN-RAG/leann-rag-evaluation-data", + repo_type="dataset", + local_dir=data_root, + local_dir_use_symlinks=False, + allow_patterns=["embeddings/**/*.pkl"], + ) + print("Embeddings download complete!") + + if dataset_type: + target_file = embeddings_dir / dataset_type / "passages_00.pkl" + if target_file.exists(): + return str(target_file) + + return str(embeddings_dir) + + except Exception as e: + print(f"Error downloading embeddings: {e}") + sys.exit(1) + + # --- Helper Function to get Golden Passages --- def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set: """ @@ -72,12 +128,76 @@ def load_queries(file_path: Path) -> List[str]: return queries +def build_index_from_embeddings( + embeddings_file: str, output_path: str, backend: str = "hnsw" +): + """ + Build a LEANN index from pre-computed embeddings. + + Args: + embeddings_file: Path to pickle file with (ids, embeddings) tuple + output_path: Path where to save the index + backend: Backend to use ("hnsw" or "diskann") + """ + print(f"Building {backend} index from embeddings: {embeddings_file}") + + # Create builder with appropriate parameters + if backend == "hnsw": + builder_kwargs = { + "M": 32, # Graph degree + "efConstruction": 256, # Construction complexity + "is_compact": True, # Use compact storage + "is_recompute": True, # Enable pruning for better recall + } + elif backend == "diskann": + builder_kwargs = { + "complexity": 64, + "graph_degree": 32, + "search_memory_maximum": 8.0, # GB + "build_memory_maximum": 16.0, # GB + } + else: + builder_kwargs = {} + + builder = LeannBuilder( + backend_name=backend, + embedding_model="facebook/contriever-msmarco", # Model used to create embeddings + dimensions=768, # Will be auto-detected from embeddings + **builder_kwargs, + ) + + # Build index from precomputed embeddings + builder.build_index_from_embeddings(output_path, embeddings_file) + print(f"Index saved to: {output_path}") + return output_path + + def main(): parser = argparse.ArgumentParser( description="Run recall evaluation on a LEANN index." ) parser.add_argument( - "index_path", type=str, help="Path to the LEANN index to evaluate." + "index_path", + type=str, + nargs="?", + help="Path to the LEANN index to evaluate or build (optional).", + ) + parser.add_argument( + "--mode", + choices=["evaluate", "build"], + default="evaluate", + help="Mode: 'evaluate' existing index or 'build' from embeddings", + ) + parser.add_argument( + "--embeddings-file", + type=str, + help="Path to embeddings pickle file (optional for build mode)", + ) + parser.add_argument( + "--backend", + choices=["hnsw", "diskann"], + default="hnsw", + help="Backend to use for building index (default: hnsw)", ) parser.add_argument( "--num-queries", type=int, default=10, help="Number of queries to evaluate." @@ -96,8 +216,90 @@ def main(): project_root = Path(__file__).resolve().parent.parent data_root = project_root / "data" - # Automatically download data if it doesn't exist - download_data_if_needed(data_root) + # Download data based on mode + if args.mode == "build": + # For building mode, we need embeddings + download_data_if_needed( + data_root, download_embeddings=False + ) # Basic data first + + # Auto-detect dataset type and download embeddings + if args.embeddings_file: + embeddings_file = args.embeddings_file + # Try to detect dataset type from embeddings file path + if "rpj_wiki" in str(embeddings_file): + dataset_type = "rpj_wiki" + elif "dpr" in str(embeddings_file): + dataset_type = "dpr" + else: + dataset_type = "dpr" # Default + else: + # Auto-detect from index path if provided, otherwise default to DPR + if args.index_path: + index_path_str = str(args.index_path) + if "rpj_wiki" in index_path_str: + dataset_type = "rpj_wiki" + elif "dpr" in index_path_str: + dataset_type = "dpr" + else: + dataset_type = "dpr" # Default to DPR + else: + dataset_type = "dpr" # Default to DPR + + embeddings_file = download_embeddings_if_needed(data_root, dataset_type) + + # Auto-generate index path if not provided + if not args.index_path: + indices_dir = data_root / "indices" / dataset_type + indices_dir.mkdir(parents=True, exist_ok=True) + args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings") + print(f"Auto-generated index path: {args.index_path}") + + print(f"Building index from embeddings: {embeddings_file}") + built_index_path = build_index_from_embeddings( + embeddings_file, args.index_path, args.backend + ) + print(f"Index built successfully: {built_index_path}") + + # Ask if user wants to run evaluation + eval_response = ( + input("Run evaluation on the built index? (y/n): ").strip().lower() + ) + if eval_response != "y": + print("Index building complete. Exiting.") + return + else: + # For evaluation mode, don't need embeddings + download_data_if_needed(data_root, download_embeddings=False) + + # Auto-detect index path if not provided + if not args.index_path: + # Default to using downloaded indices + indices_dir = data_root / "indices" + + # Try common datasets in order of preference + for dataset in ["dpr", "rpj_wiki"]: + dataset_dir = indices_dir / dataset + if dataset_dir.exists(): + # Look for index files + index_files = list(dataset_dir.glob("*.index")) + list( + dataset_dir.glob("*_disk.index") + ) + if index_files: + args.index_path = str( + index_files[0].with_suffix("") + ) # Remove .index extension + print(f"Using index: {args.index_path}") + break + + if not args.index_path: + print( + "No indices found. The data download should have included pre-built indices." + ) + print( + "Please check the data/indices/ directory or provide --index-path manually." + ) + sys.exit(1) # Detect dataset type from index path to select the correct ground truth index_path_str = str(args.index_path) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index c6e43ab..529d817 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -18,14 +18,14 @@ from .chat import get_llm def compute_embeddings( - chunks: List[str], - model_name: str, + chunks: List[str], + model_name: str, mode: str = "sentence-transformers", - use_server: bool = True + use_server: bool = True, ) -> np.ndarray: """ Computes embeddings using different backends. - + Args: chunks: List of text chunks to embed model_name: Name of the embedding model @@ -34,40 +34,48 @@ def compute_embeddings( - "mlx": Use MLX backend for Apple Silicon - "openai": Use OpenAI embedding API use_server: Whether to use embedding server (True for search, False for build) - + Returns: numpy array of embeddings """ # Auto-detect mode based on model name if not explicitly set if mode == "sentence-transformers" and model_name.startswith("text-embedding-"): mode = "openai" - + if mode == "mlx": return compute_embeddings_mlx(chunks, model_name) elif mode == "openai": return compute_embeddings_openai(chunks, model_name) elif mode == "sentence-transformers": - return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server) + return compute_embeddings_sentence_transformers( + chunks, model_name, use_server=use_server + ) else: - raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai") + raise ValueError( + f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai" + ) -def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray: +def compute_embeddings_sentence_transformers( + chunks: List[str], model_name: str, use_server: bool = True +) -> np.ndarray: """Computes embeddings using sentence-transformers. - + Args: chunks: List of text chunks to embed model_name: Name of the sentence transformer model use_server: If True, use embedding server (good for search). If False, use direct computation (good for build). """ if not use_server: - print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...") + print( + f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..." + ) return _compute_embeddings_sentence_transformers_direct(chunks, model_name) - + print( f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..." ) - + # Use embedding server for sentence-transformers too # This avoids loading the model twice (once in API, once in server) try: @@ -76,49 +84,55 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, import msgpack import numpy as np from .embedding_server_manager import EmbeddingServerManager - + # Ensure embedding server is running port = 5557 - server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server") - + server_manager = EmbeddingServerManager( + backend_module_name="leann_backend_hnsw.hnsw_embedding_server" + ) + server_started = server_manager.start_server( port=port, model_name=model_name, embedding_mode="sentence-transformers", enable_warmup=False, ) - + if not server_started: raise RuntimeError(f"Failed to start embedding server on port {port}") - + # Connect to embedding server context = zmq.Context() socket = context.socket(zmq.REQ) socket.connect(f"tcp://localhost:{port}") - + # Send chunks to server for embedding computation request = chunks socket.send(msgpack.packb(request)) - + # Receive embeddings from server response = socket.recv() embeddings_list = msgpack.unpackb(response) - + # Convert back to numpy array embeddings = np.array(embeddings_list, dtype=np.float32) - + socket.close() context.term() - + return embeddings - + except Exception as e: # Fallback to direct sentence-transformers if server connection fails - print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}") + print( + f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}" + ) return _compute_embeddings_sentence_transformers_direct(chunks, model_name) -def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray: +def _compute_embeddings_sentence_transformers_direct( + chunks: List[str], model_name: str +) -> np.ndarray: """Direct sentence-transformers computation (fallback).""" try: from sentence_transformers import SentenceTransformer @@ -159,37 +173,40 @@ def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray: raise RuntimeError( "openai not available. Install with: uv pip install openai" ) from e - + # Get API key from environment api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise RuntimeError("OPENAI_API_KEY environment variable not set") - + client = openai.OpenAI(api_key=api_key) - - print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...") - + + print( + f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..." + ) + # OpenAI has a limit on batch size and input length max_batch_size = 100 # Conservative batch size all_embeddings = [] - + for i in range(0, len(chunks), max_batch_size): - batch_chunks = chunks[i:i + max_batch_size] - print(f"INFO: Processing batch {i//max_batch_size + 1}/{(len(chunks) + max_batch_size - 1)//max_batch_size}") - + batch_chunks = chunks[i : i + max_batch_size] + print( + f"INFO: Processing batch {i // max_batch_size + 1}/{(len(chunks) + max_batch_size - 1) // max_batch_size}" + ) + try: - response = client.embeddings.create( - model=model_name, - input=batch_chunks - ) + response = client.embeddings.create(model=model_name, input=batch_chunks) batch_embeddings = [embedding.embedding for embedding in response.data] all_embeddings.extend(batch_embeddings) except Exception as e: print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}") raise - + embeddings = np.array(all_embeddings, dtype=np.float32) - print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}") + print( + f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}" + ) return embeddings @@ -308,7 +325,12 @@ class LeannBuilder: raise ValueError("No chunks added.") if self.dimensions is None: self.dimensions = len( - compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0] + compute_embeddings( + ["dummy"], + self.embedding_model, + self.embedding_mode, + use_server=False, + )[0] ) path = Path(index_path) index_dir = path.parent @@ -371,6 +393,129 @@ class LeannBuilder: with open(leann_meta_path, "w", encoding="utf-8") as f: json.dump(meta_data, f, indent=2) + def build_index_from_embeddings(self, index_path: str, embeddings_file: str): + """ + Build an index from pre-computed embeddings stored in a pickle file. + + Args: + index_path: Path where the index will be saved + embeddings_file: Path to pickle file containing (ids, embeddings) tuple + """ + # Load pre-computed embeddings + with open(embeddings_file, "rb") as f: + data = pickle.load(f) + + if not isinstance(data, tuple) or len(data) != 2: + raise ValueError( + f"Invalid embeddings file format. Expected tuple with 2 elements, got {type(data)}" + ) + + ids, embeddings = data + + if not isinstance(embeddings, np.ndarray): + raise ValueError( + f"Expected embeddings to be numpy array, got {type(embeddings)}" + ) + + if len(ids) != embeddings.shape[0]: + raise ValueError( + f"Mismatch between number of IDs ({len(ids)}) and embeddings ({embeddings.shape[0]})" + ) + + # Validate/set dimensions + embedding_dim = embeddings.shape[1] + if self.dimensions is None: + self.dimensions = embedding_dim + elif self.dimensions != embedding_dim: + raise ValueError( + f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}" + ) + + print( + f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions" + ) + + # Ensure we have text data for each embedding + if len(self.chunks) != len(ids): + # If no text chunks provided, create placeholder text entries + if not self.chunks: + print("No text chunks provided, creating placeholder entries...") + for id_val in ids: + self.add_text( + f"Document {id_val}", + metadata={"id": str(id_val), "from_embeddings": True}, + ) + else: + raise ValueError( + f"Number of text chunks ({len(self.chunks)}) doesn't match number of embeddings ({len(ids)})" + ) + + # Build file structure + path = Path(index_path) + index_dir = path.parent + index_name = path.name + index_dir.mkdir(parents=True, exist_ok=True) + passages_file = index_dir / f"{index_name}.passages.jsonl" + offset_file = index_dir / f"{index_name}.passages.idx" + + # Write passages and create offset map + offset_map = {} + with open(passages_file, "w", encoding="utf-8") as f: + for chunk in self.chunks: + offset = f.tell() + json.dump( + { + "id": chunk["id"], + "text": chunk["text"], + "metadata": chunk["metadata"], + }, + f, + ensure_ascii=False, + ) + f.write("\n") + offset_map[chunk["id"]] = offset + + with open(offset_file, "wb") as f: + pickle.dump(offset_map, f) + + # Build the vector index using precomputed embeddings + string_ids = [str(id_val) for id_val in ids] + current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} + builder_instance = self.backend_factory.builder(**current_backend_kwargs) + builder_instance.build(embeddings, string_ids, index_path) + + # Create metadata file + leann_meta_path = index_dir / f"{index_name}.meta.json" + meta_data = { + "version": "1.0", + "backend_name": self.backend_name, + "embedding_model": self.embedding_model, + "dimensions": self.dimensions, + "backend_kwargs": self.backend_kwargs, + "embedding_mode": self.embedding_mode, + "passage_sources": [ + { + "type": "jsonl", + "path": str(passages_file), + "index_path": str(offset_file), + } + ], + "built_from_precomputed_embeddings": True, + "embeddings_source": str(embeddings_file), + } + + # Add storage status flags for HNSW backend + if self.backend_name == "hnsw": + is_compact = self.backend_kwargs.get("is_compact", True) + is_recompute = self.backend_kwargs.get("is_recompute", True) + meta_data["is_compact"] = is_compact + meta_data["is_pruned"] = is_compact and is_recompute + + with open(leann_meta_path, "w", encoding="utf-8") as f: + json.dump(meta_data, f, indent=2) + + print(f"Index built successfully from precomputed embeddings: {index_path}") + class LeannSearcher: def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs): @@ -382,7 +527,9 @@ class LeannSearcher: backend_name = self.meta_data["backend_name"] self.embedding_model = self.meta_data["embedding_model"] # Support both old and new format - self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") + self.embedding_mode = self.meta_data.get( + "embedding_mode", "sentence-transformers" + ) # Backward compatibility with use_mlx if self.meta_data.get("use_mlx", False): self.embedding_mode = "mlx" @@ -414,6 +561,7 @@ class LeannSearcher: # Use backend's compute_query_embedding method # This will automatically use embedding server if available and needed import time + start_time = time.time() query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port) @@ -513,7 +661,7 @@ class LeannChat: "Please provide the best answer you can based on this context and your knowledge." ) - ans=self.llm.ask(prompt, **llm_kwargs) + ans = self.llm.ask(prompt, **llm_kwargs) return ans def start_interactive(self):