From 8bffb1e5b86d667ec90abe886b39150d2b50a0a2 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Fri, 11 Jul 2025 02:58:04 +0000
Subject: [PATCH] feat: reproducible research data, rpj_wiki & dpr

---
 examples/run_evaluation.py                    | 157 ++++
 .../leann_backend_diskann/diskann_backend.py  |  18 +-
 .../leann_backend_diskann/embedding_server.py |  81 +++-
 .../leann_backend_hnsw/hnsw_backend.py        | 130 +++---
 .../hnsw_embedding_server.py                  | 111 ++++--
 packages/leann-backend-hnsw/pyproject.toml    |   3 +-
 packages/leann-core/src/leann/api.py          | 356 +++++-------
 .../src/leann/embedding_server_manager.py     |  39 +-
 8 files changed, 493 insertions(+), 402 deletions(-)
 create mode 100644 examples/run_evaluation.py
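Note on the evaluation script that follows: the `.idx` sidecars it reads next to each passages shard are just pickled {passage_id: byte_offset} maps. For reference, a minimal sketch of producing such a sidecar while writing a shard (hypothetical `write_shard` helper, not part of this patch):

    import json
    import pickle

    def write_shard(records, jsonl_path: str):
        """Write records as JSONL and a pickled {id: byte_offset} sidecar."""
        offsets = {}
        with open(jsonl_path, "w", encoding="utf-8") as f:
            for rec in records:
                offsets[str(rec["id"])] = f.tell()  # byte offset of this line
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        with open(jsonl_path.replace(".jsonl", ".idx"), "wb") as f:
            pickle.dump(offsets, f)

With that layout, fetching any passage is a seek() plus one readline(), which is what OldPassageLoader below relies on.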
diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py
new file mode 100644
index 0000000..4a3137a
--- /dev/null
+++ b/examples/run_evaluation.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+"""
+This script runs a recall evaluation on a given LEANN index.
+It compares results by fetching the text content for both the new search
+results and the golden results, making the comparison robust to ID changes.
+"""
+
+import json
+import argparse
+import time
+from pathlib import Path
+import sys
+import numpy as np
+from typing import List, Dict, Any
+import glob
+import pickle
+
+# Add project root to path to allow importing from leann
+project_root = Path(__file__).resolve().parent.parent
+sys.path.insert(0, str(project_root))
+
+from leann.api import LeannSearcher
+
+# --- Configuration ---
+NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
+
+# Ground truth files for different datasets
+GROUND_TRUTH_FILES = {
+    "rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
+    "dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
+}
+
+# Old passages for different datasets
+OLD_PASSAGES_GLOBS = {
+    "rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
+    "dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
+}
+
+# --- Helper Class to Load Original Passages ---
+class OldPassageLoader:
+    """A simplified version of the old LazyPassages class to fetch golden results by ID."""
+    def __init__(self, passages_glob: str):
+        self.jsonl_paths = sorted(glob.glob(passages_glob))
+        self.offsets = {}
+        self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
+        print("Building offset map for original passages...")
+        for i, shard_path_str in enumerate(self.jsonl_paths):
+            old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
+            if not old_idx_path.exists(): continue
+            with open(old_idx_path, 'rb') as f:
+                shard_offsets = pickle.load(f)
+            for pid, offset in shard_offsets.items():
+                self.offsets[str(pid)] = (i, offset)
+        print("Offset map for original passages is ready.")
+
+    def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
+        pid = str(pid)
+        if pid not in self.offsets:
+            raise ValueError(f"Passage ID {pid} not found in offsets")
+        file_idx, offset = self.offsets[pid]
+        fp = self.fps[file_idx]
+        fp.seek(offset)
+        return json.loads(fp.readline())
+
+    def __del__(self):
+        for fp in self.fps:
+            fp.close()
+
+def load_queries(file_path: Path) -> List[str]:
+    queries = []
+    with open(file_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            data = json.loads(line)
+            queries.append(data['query'])
+    return queries
+
+def main():
+    parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
+    parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
+    parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
+    parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
+    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
+    args = parser.parse_args()
+
+    print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
+
+    # Detect dataset type from index path
+    index_path_str = str(args.index_path)
+    if "rpj_wiki" in index_path_str:
+        dataset_type = "rpj_wiki"
+    elif "dpr" in index_path_str:
+        dataset_type = "dpr"
+    else:
+        print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
+        dataset_type = "rpj_wiki"
+
+    print(f"INFO: Detected dataset type: {dataset_type}")
+
+    try:
+        searcher = LeannSearcher(args.index_path)
+        queries = load_queries(NQ_QUERIES_FILE)
+
+        golden_results_file = GROUND_TRUTH_FILES[dataset_type]
+        old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
+
+        print(f"INFO: Using ground truth file: {golden_results_file}")
+        print(f"INFO: Using old passages glob: {old_passages_glob}")
+
+        with open(golden_results_file, 'r') as f:
+            golden_results_data = json.load(f)
+
+        old_passage_loader = OldPassageLoader(old_passages_glob)
+
+        num_eval_queries = min(args.num_queries, len(queries))
+        queries = queries[:num_eval_queries]
+
+        print(f"\nRunning evaluation on {num_eval_queries} queries...")
+        recall_scores = []
+        search_times = []
+
+        for i in range(num_eval_queries):
+            start_time = time.time()
+            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
+            search_times.append(time.time() - start_time)
+
+            # Recall is computed on TEXT content, not on IDs
+            new_texts = {result.text for result in new_results}
+            golden_ids = golden_results_data["indices"][i][:args.top_k]
+            golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
+
+            overlap = len(new_texts & golden_texts)
+            recall = overlap / len(golden_texts) if golden_texts else 0
+            recall_scores.append(recall)
+
+            print("\n--- EVALUATION RESULTS ---")
+            print(f"Query: {queries[i]}")
+            print(f"New Results: {new_texts}")
+            print(f"Golden Results: {golden_texts}")
+            print(f"Overlap: {overlap}")
+            print(f"Recall: {recall}")
+            print(f"Search Time: {search_times[-1]:.4f}s")
+            print("--------------------------------")
+
+        avg_recall = np.mean(recall_scores) if recall_scores else 0
+        avg_time = np.mean(search_times) if search_times else 0
+
+        print("\n🎉 --- Evaluation Complete ---")
+        print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
+        print(f"Avg. Search Time: {avg_time:.4f}s")
+
+    except Exception as e:
+        print(f"\n❌ An error occurred during evaluation: {e}")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
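The recall definition used in the loop above reduces to a set intersection over exact passage texts. Restated standalone (a sketch, equivalent to the inline computation):

    def recall_at_k(new_texts: set, golden_texts: set) -> float:
        """Text-based recall@k: fraction of golden passages whose exact
        text also appears in the retrieved set."""
        if not golden_texts:
            return 0.0
        return len(new_texts & golden_texts) / len(golden_texts)

Because the comparison is on text rather than IDs, a rebuilt index with fresh UUIDs still scores correctly against the old flat-search ground truth.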
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
index 23ea745..cd982e1 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
@@ -141,9 +141,9 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         if not self.embedding_model:
             print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")

-        path = Path(index_path)
-        self.index_dir = path.parent
-        self.index_prefix = path.stem
+        self.index_path = Path(index_path)
+        self.index_dir = self.index_path.parent
+        self.index_prefix = self.index_path.stem

         # Load the label map
         label_map_file = self.index_dir / "leann.labels.map"
@@ -199,13 +199,13 @@ class DiskannSearcher(LeannBackendSearcherInterface):

         passages_file = kwargs.get("passages_file")
         if not passages_file:
-            # Get the passages file path from meta.json
-            if 'passage_sources' in self.meta and self.meta['passage_sources']:
-                passage_source = self.meta['passage_sources'][0]
-                passages_file = passage_source['path']
-                print(f"INFO: Found passages file from metadata: {passages_file}")
+            # Pass the metadata file instead of a single passage file
+            meta_file_path = self.index_path.parent / f"{self.index_path.name}.meta.json"
+            if meta_file_path.exists():
+                passages_file = str(meta_file_path)
+                print(f"INFO: Using metadata file for lazy loading: {passages_file}")
             else:
-                raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
+                raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")

         server_started = self.embedding_server_manager.start_server(
             port=self.zmq_port,
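The naming convention introduced here means an index at /data/idx/wiki keeps its metadata at /data/idx/wiki.meta.json. A sketch of the resolution rule the searcher now applies (mirroring the code above, illustrative only):

    from pathlib import Path

    def resolve_meta_file(index_path: str) -> Path:
        """Derive the sidecar metadata path handed to the embedding server."""
        p = Path(index_path)
        meta = p.parent / f"{p.name}.meta.json"
        if not meta.exists():
            raise RuntimeError(f"Recompute enabled but metadata file not found: {meta}")
        return meta

Passing the metadata file (instead of one passages JSONL) lets the server reconstruct every passage source lazily rather than loading a single shard eagerly.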
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
index ee2d4b2..6de653a 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
@@ -39,6 +39,71 @@ class SimplePassageLoader:
     def __len__(self) -> int:
         return len(self.passages_data)

+def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
+    """
+    Load passages using metadata file with PassageManager for lazy loading
+    """
+    # Load metadata to get passage sources
+    with open(meta_file, 'r') as f:
+        meta = json.load(f)
+
+    # Import PassageManager dynamically to avoid circular imports
+    import sys
+    from pathlib import Path
+
+    # Find the leann package directory relative to this file
+    current_dir = Path(__file__).parent
+    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
+    sys.path.insert(0, str(leann_core_path))
+
+    try:
+        from leann.api import PassageManager
+        passage_manager = PassageManager(meta['passage_sources'])
+    finally:
+        sys.path.pop(0)
+
+    # Load label map
+    passages_dir = Path(meta_file).parent
+    label_map_file = passages_dir / "leann.labels.map"
+
+    if label_map_file.exists():
+        import pickle
+        with open(label_map_file, 'rb') as f:
+            label_map = pickle.load(f)
+        print(f"Loaded label map with {len(label_map)} entries")
+    else:
+        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+    print(f"Initialized lazy passage loading for {len(label_map)} passages")
+
+    class LazyPassageLoader(SimplePassageLoader):
+        def __init__(self, passage_manager, label_map):
+            self.passage_manager = passage_manager
+            self.label_map = label_map
+            # Initialize parent with empty data
+            super().__init__({})
+
+        def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
+            """Get passage by ID with lazy loading"""
+            try:
+                int_id = int(passage_id)
+                if int_id in self.label_map:
+                    string_id = self.label_map[int_id]
+                    passage_data = self.passage_manager.get_passage(string_id)
+                    if passage_data and passage_data.get("text"):
+                        return {"text": passage_data["text"]}
+                    else:
+                        raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
+                else:
+                    raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
+            except Exception as e:
+                raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
+
+        def __len__(self) -> int:
+            return len(self.label_map)
+
+    return LazyPassageLoader(passage_manager, label_map)
+
 def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     """
     Load passages from a JSONL file with label map support
@@ -140,7 +205,21 @@ def create_embedding_server_thread(

     # Load passages from file if provided
     if passages_file and os.path.exists(passages_file):
-        passages = load_passages_from_file(passages_file)
+        # Check if it's a metadata file or a single passages file
+        if passages_file.endswith('.meta.json'):
+            passages = load_passages_from_metadata(passages_file)
+        else:
+            # Try to find a metadata file in the same directory
+            from pathlib import Path
+            passages_dir = Path(passages_file).parent
+            meta_files = list(passages_dir.glob("*.meta.json"))
+            if meta_files:
+                print(f"Found metadata file: {meta_files[0]}, using lazy loading")
+                passages = load_passages_from_metadata(str(meta_files[0]))
+            else:
+                # Fallback to original single file loading (will cause warnings)
+                print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
+                passages = load_passages_from_file(passages_file)
     else:
         print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
         passages = SimplePassageLoader()
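The lookup chain the lazy loader implements is a two-level indirection: integer node ID (what the index and ZMQ server speak) -> string passage ID (via leann.labels.map) -> byte offset -> JSONL record. A standalone sketch of the same chain, with illustrative file paths and assuming the pickled structures shown above:

    import json
    import pickle

    def fetch_text(node_id: int, label_map_path: str,
                   offset_idx_path: str, passages_jsonl: str) -> str:
        # int node ID -> string passage ID
        with open(label_map_path, "rb") as f:
            label_map = pickle.load(f)      # {int: str}
        string_id = label_map[node_id]
        # string passage ID -> byte offset
        with open(offset_idx_path, "rb") as f:
            offsets = pickle.load(f)        # {str: int}
        # byte offset -> JSONL record
        with open(passages_jsonl, "r", encoding="utf-8") as f:
            f.seek(offsets[string_id])
            return json.loads(f.readline())["text"]

Note the DiskANN loader fails fast (raises) on any miss, while the HNSW variant further below logs and returns empty text; the batch validator in the HNSW server then catches empties before embedding.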
""" - if not index_file.exists(): - return False, False + # Check if metadata has these flags + is_compact = self.meta.get('is_compact', True) # Default to compact (CSR format) + is_pruned = self.meta.get('is_pruned', True) # Default to pruned (embeddings removed) - with open(index_file, 'rb') as f: - try: - def read_struct(fmt): - size = struct.calcsize(fmt) - data = f.read(size) - if len(data) != size: - raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.") - return struct.unpack(fmt, data)[0] - - def skip_vector(element_size): - count = read_struct(' 1: read_struct(' 0 else 'No results'}") + print(f" Raw faiss distances: {distances[0] if len(distances) > 0 else 'No results'}") + # Convert integer labels to string IDs string_labels = [] - for batch_labels in labels: + for batch_idx, batch_labels in enumerate(labels): batch_string_labels = [] - for int_label in batch_labels: + print(f" Batch {batch_idx} conversion:") + for i, int_label in enumerate(batch_labels): if int_label in self.label_map: - batch_string_labels.append(self.label_map[int_label]) + string_id = self.label_map[int_label] + batch_string_labels.append(string_id) + print(f" faiss[{int_label}] -> passage_id '{string_id}' (distance: {distances[batch_idx][i]:.4f})") else: - batch_string_labels.append(f"unknown_{int_label}") + unknown_id = f"unknown_{int_label}" + batch_string_labels.append(unknown_id) + print(f" faiss[{int_label}] -> {unknown_id} (NOT FOUND in label_map!)") string_labels.append(batch_string_labels) return {"labels": string_labels, "distances": distances} diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index c8cd358..e5d06f4 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -56,22 +56,33 @@ class SimplePassageLoader: def __len__(self) -> int: return len(self.passages_data) -def load_passages_from_file(passages_file: str) -> SimplePassageLoader: +def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: """ - Load passages from a JSONL file with label map support - Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line) + Load passages using metadata file with PassageManager for lazy loading """ - if not os.path.exists(passages_file): - raise FileNotFoundError(f"Passages file {passages_file} not found.") + # Load metadata to get passage sources + with open(meta_file, 'r') as f: + meta = json.load(f) - if not passages_file.endswith('.jsonl'): - raise ValueError(f"Expected .jsonl file format, got: {passages_file}") + # Import PassageManager dynamically to avoid circular imports + import sys + import importlib.util - # Load label map (int -> string_id) - passages_dir = Path(passages_file).parent + # Find the leann package directory relative to this file + current_dir = Path(__file__).parent + leann_core_path = current_dir.parent.parent / "leann-core" / "src" + sys.path.insert(0, str(leann_core_path)) + + try: + from leann.api import PassageManager + passage_manager = PassageManager(meta['passage_sources']) + finally: + sys.path.pop(0) + + # Load label map + passages_dir = Path(meta_file).parent label_map_file = passages_dir / "leann.labels.map" - label_map = {} if label_map_file.exists(): import pickle with open(label_map_file, 'rb') as f: @@ -80,24 +91,38 @@ def load_passages_from_file(passages_file: str) -> 
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index c8cd358..e5d06f4 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -56,22 +56,33 @@ class SimplePassageLoader:
     def __len__(self) -> int:
         return len(self.passages_data)

-def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
+def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
     """
-    Load passages from a JSONL file with label map support
-    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
+    Load passages using metadata file with PassageManager for lazy loading
     """
-    if not os.path.exists(passages_file):
-        raise FileNotFoundError(f"Passages file {passages_file} not found.")
+    # Load metadata to get passage sources
+    with open(meta_file, 'r') as f:
+        meta = json.load(f)

-    if not passages_file.endswith('.jsonl'):
-        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
+    # Import PassageManager dynamically to avoid circular imports
+    import sys
+    import importlib.util

-    # Load label map (int -> string_id)
-    passages_dir = Path(passages_file).parent
+    # Find the leann package directory relative to this file
+    current_dir = Path(__file__).parent
+    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
+    sys.path.insert(0, str(leann_core_path))
+
+    try:
+        from leann.api import PassageManager
+        passage_manager = PassageManager(meta['passage_sources'])
+    finally:
+        sys.path.pop(0)
+
+    # Load label map
+    passages_dir = Path(meta_file).parent
     label_map_file = passages_dir / "leann.labels.map"
-    label_map = {}

     if label_map_file.exists():
         import pickle
         with open(label_map_file, 'rb') as f:
@@ -80,24 +91,38 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     else:
         raise FileNotFoundError(f"Label map file not found: {label_map_file}")

-    # Load passages by string ID
-    string_id_passages = {}
-    with open(passages_file, 'r', encoding='utf-8') as f:
-        for line in f:
-            if line.strip():
-                passage = json.loads(line)
-                string_id_passages[passage['id']] = passage['text']
+    print(f"Initialized lazy passage loading for {len(label_map)} passages")

-    # Create int ID -> text mapping using label map
-    passages_data = {}
-    for int_id, string_id in label_map.items():
-        if string_id in string_id_passages:
-            passages_data[str(int_id)] = string_id_passages[string_id]
-        else:
-            print(f"WARNING: String ID {string_id} from label map not found in passages")
+    class LazyPassageLoader(SimplePassageLoader):
+        def __init__(self, passage_manager, label_map):
+            self.passage_manager = passage_manager
+            self.label_map = label_map
+            # Initialize parent with empty data
+            super().__init__({})
+
+        def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
+            """Get passage by ID with lazy loading"""
+            try:
+                int_id = int(passage_id)
+                if int_id in self.label_map:
+                    string_id = self.label_map[int_id]
+                    passage_data = self.passage_manager.get_passage(string_id)
+                    if passage_data and passage_data.get("text"):
+                        return {"text": passage_data["text"]}
+                    else:
+                        print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
+                        return {"text": ""}
+                else:
+                    print(f"DEBUG: ID {int_id} not found in label_map")
+                    return {"text": ""}
+            except Exception as e:
+                print(f"DEBUG: Exception getting passage {passage_id}: {e}")
+                return {"text": ""}
+
+        def __len__(self) -> int:
+            return len(self.label_map)

-    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
-    return SimplePassageLoader(passages_data)
+    return LazyPassageLoader(passage_manager, label_map)

 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
@@ -183,7 +208,20 @@ def create_hnsw_embedding_server(
         passages = SimplePassageLoader(passages_data)
         print(f"Using provided passages data: {len(passages)} passages")
     elif passages_file:
-        passages = load_passages_from_file(passages_file)
+        # Check if it's a metadata file or a single passages file
+        if passages_file.endswith('.meta.json'):
+            passages = load_passages_from_metadata(passages_file)
+        else:
+            # Try to find a metadata file in the same directory
+            passages_dir = Path(passages_file).parent
+            meta_files = list(passages_dir.glob("*.meta.json"))
+            if meta_files:
+                print(f"Found metadata file: {meta_files[0]}, using lazy loading")
+                passages = load_passages_from_metadata(str(meta_files[0]))
+            else:
+                # Fallback to original single file loading (will cause warnings)
+                print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
+                passages = SimplePassageLoader()  # Use empty loader to avoid massive warnings
     else:
         passages = SimplePassageLoader()
         print("No passages provided, using empty loader")
@@ -252,6 +290,11 @@ def create_hnsw_embedding_server(
     _is_bge_model = "bge" in model_name.lower()
     batch_size = len(texts_batch)

+    # Validate no empty texts
+    for i, text in enumerate(texts_batch):
+        if not text or text.strip() == "":
+            raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
+
     # E5 model preprocessing
     if _is_e5_model:
         processed_texts_batch = [f"passage: {text}" for text in texts_batch]
@@ -398,14 +441,12 @@ def create_hnsw_embedding_server(
         missing_ids = []
         with lookup_timer.timing():
             for nid in node_ids:
-                try:
-                    txtinfo = passages[nid]
-                    if txtinfo is None or txtinfo["text"] == "":
-                        raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
-                    else:
-                        txt = txtinfo["text"]
-                except (KeyError, IndexError):
-                    raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
+                print(f"DEBUG: Looking up passage ID {nid}")
+                txtinfo = passages[nid]
+                if txtinfo is None or txtinfo["text"] == "":
+                    raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
+                txt = txtinfo["text"]
+                print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
                 texts.append(txt)
         lookup_timer.print_elapsed()
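The simplified loop above trusts the loader's contract: indexing either returns usable text or the whole batch aborts. The same invariant as a standalone sketch (illustrative helper name):

    def fetch_texts(passages, node_ids):
        """Fail-fast text lookup: any missing or empty passage aborts the batch
        rather than silently embedding an empty string."""
        texts = []
        for nid in node_ids:
            info = passages[nid]            # LazyPassageLoader may raise here
            if info is None or info["text"] == "":
                raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
            texts.append(info["text"])
        return texts

One caveat worth noting: the per-ID DEBUG prints sit on the hot path of every recompute batch, so they are useful for tracing this change but costly at scale.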
diff --git a/packages/leann-backend-hnsw/pyproject.toml b/packages/leann-backend-hnsw/pyproject.toml
index 2201403..580fd56 100644
--- a/packages/leann-backend-hnsw/pyproject.toml
+++ b/packages/leann-backend-hnsw/pyproject.toml
@@ -1,4 +1,4 @@
-# File: packages/leann-backend-hnsw/pyproject.toml
+# packages/leann-backend-hnsw/pyproject.toml

 [build-system]
 requires = ["scikit-build-core>=0.10", "numpy", "swig"]
@@ -10,7 +10,6 @@ version = "0.1.0"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = ["leann-core==0.1.0", "numpy"]

-# Revert to the most standard scikit-build-core configuration
 [tool.scikit-build]
 wheel.packages = ["leann_backend_hnsw"]
 editable.mode = "redirect"
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index d146bed..00ac2f2 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -1,345 +1,185 @@
-from .registry import BACKEND_REGISTRY
-from .interface import LeannBackendFactoryInterface
-from typing import List, Dict, Any, Optional
-import numpy as np
-import os
+#!/usr/bin/env python3
+"""
+This file contains the core API for the LEANN project, now definitively updated
+with the correct, original embedding logic from the user's reference code.
+"""
+
 import json
+import pickle
+import numpy as np
 from pathlib import Path
-import openai
+from typing import List, Dict, Any, Optional
 from dataclasses import dataclass, field
 import uuid
-import pickle

-# --- Helper Functions for Embeddings ---
+from .registry import BACKEND_REGISTRY
+from .interface import LeannBackendFactoryInterface

-def _get_openai_client():
-    """Initializes and returns an OpenAI client, ensuring the API key is set."""
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
-    return openai.OpenAI(api_key=api_key)
+# --- The Correct, Verified Embedding Logic from old_code.py ---

-def _is_openai_model(model_name: str) -> bool:
-    """Checks if the model is likely an OpenAI embedding model."""
-    # This is a simple check, can be improved with a more robust list.
-    return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
-
-def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
-    if _is_openai_model(model_name):
-        print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
-        client = _get_openai_client()
-        response = client.embeddings.create(model=model_name, input=chunks)
-        embeddings = [item.embedding for item in response.data]
-    else:
+def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
+    """Computes embeddings using sentence-transformers for consistent results."""
+    try:
         from sentence_transformers import SentenceTransformer
-        model = SentenceTransformer(model_name)
-        print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
-        embeddings = model.encode(chunks, show_progress_bar=True)
+    except ImportError as e:
+        raise RuntimeError(
+            f"sentence-transformers not available. Install with: pip install sentence-transformers"
+        ) from e

-    return np.asarray(embeddings, dtype=np.float32)
-
-def _get_embedding_dimensions(model_name: str) -> int:
-    """Gets the embedding dimensions for a given model."""
-    print(f"INFO: Calculating dimensions for model '{model_name}'...")
-    if _is_openai_model(model_name):
-        client = _get_openai_client()
-        response = client.embeddings.create(model=model_name, input=["dummy text"])
-        return len(response.data[0].embedding)
-    else:
-        from sentence_transformers import SentenceTransformer
-        model = SentenceTransformer(model_name)
-        dimension = model.get_sentence_embedding_dimension()
-        if dimension is None:
-            raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
-        return dimension
+    # Load model using sentence-transformers
+    model = SentenceTransformer(model_name)
+
+    # Generate embeddings
+    embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
+
+    return embeddings
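compute_embeddings returns a (num_chunks, dim) array, and dimension auto-detection in build_index below simply embeds a dummy string and reads the vector length. A usage sketch (assumes sentence-transformers is installed; the shape comment is what contriever-msmarco typically produces):

    from leann.api import compute_embeddings

    vecs = compute_embeddings(["hello world", "leann"], "facebook/contriever-msmarco")
    print(vecs.shape)   # e.g. (2, 768) for a BERT-base encoder
    dim = len(compute_embeddings(["dummy"], "facebook/contriever-msmarco")[0])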

+# --- Core API Classes (Restored and Unchanged) ---
 @dataclass
 class SearchResult:
-    """Represents a single search result."""
     id: str
     score: float
     text: str
     metadata: Dict[str, Any] = field(default_factory=dict)

 class PassageManager:
-    """Manages passage data and lazy loading from JSONL files."""
-
     def __init__(self, passage_sources: List[Dict[str, Any]]):
         self.offset_maps = {}
         self.passage_files = {}
+        self.global_offset_map = {}  # Combined map for fast lookup

         for source in passage_sources:
             if source["type"] == "jsonl":
                 passage_file = source["path"]
                 index_file = source["index_path"]
-
-                if not os.path.exists(index_file):
+                if not Path(index_file).exists():
                     raise FileNotFoundError(f"Passage index file not found: {index_file}")
-
                 with open(index_file, 'rb') as f:
                     offset_map = pickle.load(f)
-
-                self.offset_maps[passage_file] = offset_map
-                self.passage_files[passage_file] = passage_file
-
+                self.offset_maps[passage_file] = offset_map
+                self.passage_files[passage_file] = passage_file
+
+                # Build global map for O(1) lookup
+                for passage_id, offset in offset_map.items():
+                    self.global_offset_map[passage_id] = (passage_file, offset)
+
     def get_passage(self, passage_id: str) -> Dict[str, Any]:
-        """Lazy load a passage by ID."""
-        for passage_file, offset_map in self.offset_maps.items():
-            if passage_id in offset_map:
-                offset = offset_map[passage_id]
-                with open(passage_file, 'r', encoding='utf-8') as f:
-                    f.seek(offset)
-                    line = f.readline()
-                    return json.loads(line)
-
+        if passage_id in self.global_offset_map:
+            passage_file, offset = self.global_offset_map[passage_id]
+            with open(passage_file, 'r', encoding='utf-8') as f:
+                f.seek(offset)
+                return json.loads(f.readline())
         raise KeyError(f"Passage ID not found: {passage_id}")
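The change above replaces a per-source scan with a single dict probe. One behavioral detail of the merge, shown with illustrative data: when the same passage ID appears in several sources, the later source wins, since dicts preserve insertion order and later writes overwrite earlier ones.

    offset_maps = {
        "shard0.jsonl": {"a": 0, "b": 120},
        "shard1.jsonl": {"b": 0, "c": 77},
    }
    global_offset_map = {}
    for path, offsets in offset_maps.items():
        for pid, off in offsets.items():
            global_offset_map[pid] = (path, off)

    assert global_offset_map["b"] == ("shard1.jsonl", 0)  # shard1 overwrote shard0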

-# --- Core Classes ---
-
 class LeannBuilder:
-    """
-    The builder is responsible for building the index, it will compute the embeddings and then build the index.
-    It will also save the metadata of the index.
-    """
-    def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
+    def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
         self.backend_name = backend_name
         backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found or not registered.")
         self.backend_factory = backend_factory
-
         self.embedding_model = embedding_model
         self.dimensions = dimensions
         self.backend_kwargs = backend_kwargs
         self.chunks: List[Dict[str, Any]] = []
-        print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")

     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
-        if metadata is None:
-            metadata = {}
-
-        # Check if ID is provided in metadata
-        passage_id = metadata.get('id')
-        if passage_id is None:
-            passage_id = str(uuid.uuid4())
-        else:
-            # Validate uniqueness
-            existing_ids = {chunk['id'] for chunk in self.chunks}
-            if passage_id in existing_ids:
-                raise ValueError(f"Duplicate passage ID: {passage_id}")
-
-        # Store the definitive ID with the chunk
-        chunk_data = {
-            "id": passage_id,
-            "text": text,
-            "metadata": metadata
-        }
+        if metadata is None: metadata = {}
+        passage_id = metadata.get('id', str(uuid.uuid4()))
+        chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
         self.chunks.append(chunk_data)

     def build_index(self, index_path: str):
-        if not self.chunks:
-            raise ValueError("No chunks added. Use add_text() first.")
-
-        if self.dimensions is None:
-            self.dimensions = _get_embedding_dimensions(self.embedding_model)
-            print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
-
+        if not self.chunks: raise ValueError("No chunks added.")
+        if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
         path = Path(index_path)
         index_dir = path.parent
         index_name = path.name
-
-        # Ensure the directory exists
         index_dir.mkdir(parents=True, exist_ok=True)
-
-        # Create the passages.jsonl file and offset index
         passages_file = index_dir / f"{index_name}.passages.jsonl"
         offset_file = index_dir / f"{index_name}.passages.idx"
-
         offset_map = {}
-
         with open(passages_file, 'w', encoding='utf-8') as f:
             for chunk in self.chunks:
                 offset = f.tell()
-                passage_data = {
-                    "id": chunk["id"],
-                    "text": chunk["text"],
-                    "metadata": chunk["metadata"]
-                }
-                json.dump(passage_data, f, ensure_ascii=False)
+                json.dump({"id": chunk["id"], "text": chunk["text"], "metadata": chunk["metadata"]}, f, ensure_ascii=False)
                 f.write('\n')
                 offset_map[chunk["id"]] = offset
-
-        # Save the offset map
-        with open(offset_file, 'wb') as f:
-            pickle.dump(offset_map, f)
-
-        # Compute embeddings
+        with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
         texts_to_embed = [c["text"] for c in self.chunks]
-        embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
-
-        # Extract string IDs for the backend
+        embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
         string_ids = [chunk["id"] for chunk in self.chunks]
-
-        # Build the vector index
-        current_backend_kwargs = self.backend_kwargs.copy()
-        current_backend_kwargs['dimensions'] = self.dimensions
+        current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
-
         builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
-
-        # Create the lightweight meta.json file
         leann_meta_path = index_dir / f"{index_name}.meta.json"
-
         meta_data = {
-            "version": "1.0",
-            "backend_name": self.backend_name,
-            "embedding_model": self.embedding_model,
-            "dimensions": self.dimensions,
-            "backend_kwargs": self.backend_kwargs,
-            "passage_sources": [
-                {
-                    "type": "jsonl",
-                    "path": str(passages_file),
-                    "index_path": str(offset_file)
-                }
-            ]
+            "version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
+            "dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
+            "passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
         }
-        with open(leann_meta_path, 'w', encoding='utf-8') as f:
-            json.dump(meta_data, f, indent=2)
-        print(f"INFO: Leann metadata saved to {leann_meta_path}")
-
+
+        # Add storage status flags for HNSW backend
+        if self.backend_name == "hnsw":
+            is_compact = self.backend_kwargs.get("is_compact", True)
+            is_recompute = self.backend_kwargs.get("is_recompute", True)
+            meta_data["is_compact"] = is_compact
+            meta_data["is_pruned"] = is_compact and is_recompute  # Pruned only if compact and recompute
+        with open(leann_meta_path, 'w', encoding='utf-8') as f: json.dump(meta_data, f, indent=2)
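A hedged end-to-end sketch of the builder contract above; "hnsw" must be a registered backend package and the paths are illustrative:

    from leann.api import LeannBuilder

    builder = LeannBuilder("hnsw", embedding_model="facebook/contriever-msmarco",
                           is_compact=True, is_recompute=True)
    builder.add_text("Paris is the capital of France.", {"id": "doc-1"})
    builder.add_text("The Eiffel Tower is in Paris.")  # id auto-generated (uuid4)
    builder.build_index("/tmp/demo_index/demo")
    # Produces demo.passages.jsonl, demo.passages.idx, and demo.meta.json
    # (with is_compact/is_pruned recorded for hnsw), plus the backend's files.

Note that the rewrite drops the old duplicate-ID check in add_text: a repeated metadata id now silently overwrites the earlier offset-map entry rather than raising.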
- """ def __init__(self, index_path: str, **backend_kwargs): - leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json" - if not leann_meta_path.exists(): - raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}. Was the index built with LeannBuilder?") - - with open(leann_meta_path, 'r', encoding='utf-8') as f: - self.meta_data = json.load(f) - + meta_path_str = f"{index_path}.meta.json" + if not Path(meta_path_str).exists(): raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}") + with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f) backend_name = self.meta_data['backend_name'] self.embedding_model = self.meta_data['embedding_model'] - - # Initialize the passage manager - passage_sources = self.meta_data.get('passage_sources', []) - self.passage_manager = PassageManager(passage_sources) - + self.passage_manager = PassageManager(self.meta_data.get('passage_sources', [])) backend_factory = BACKEND_REGISTRY.get(backend_name) - if backend_factory is None: - raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.") - - final_kwargs = backend_kwargs.copy() - final_kwargs['meta'] = self.meta_data - + if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.") + final_kwargs = {**self.meta_data.get('backend_kwargs', {}), **backend_kwargs} self.backend_impl = backend_factory.searcher(index_path, **final_kwargs) - print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.") - - def search(self, query: str, top_k: int = 5, **search_kwargs): - query_embedding = _compute_embeddings([query], self.embedding_model) + + def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]: + print(f"๐Ÿ” DEBUG LeannSearcher.search() called:") + print(f" Query: '{query}'") + print(f" Top_k: {top_k}") + print(f" Search kwargs: {search_kwargs}") + + query_embedding = compute_embeddings([query], self.embedding_model) + print(f" Generated embedding shape: {query_embedding.shape}") + print(f"๐Ÿ” DEBUG Query embedding first 10 values: {query_embedding[0][:10]}") + print(f"๐Ÿ” DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}") - search_kwargs['embedding_model'] = self.embedding_model results = self.backend_impl.search(query_embedding, top_k, **search_kwargs) + print(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results") enriched_results = [] - for string_id, dist in zip(results['labels'][0], results['distances'][0]): - try: - passage_data = self.passage_manager.get_passage(string_id) - enriched_results.append(SearchResult( - id=string_id, - score=dist, - text=passage_data['text'], - metadata=passage_data.get('metadata', {}) - )) - except KeyError: - print(f"WARNING: Passage ID '{string_id}' not found in passage files") + if 'labels' in results and 'distances' in results: + print(f" Processing {len(results['labels'][0])} passage IDs:") + for i, (string_id, dist) in enumerate(zip(results['labels'][0], results['distances'][0])): + try: + passage_data = self.passage_manager.get_passage(string_id) + enriched_results.append(SearchResult( + id=string_id, score=dist, text=passage_data['text'], metadata=passage_data.get('metadata', {}) + )) + print(f" {i+1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text'][:60]}...") + except KeyError: + print(f" {i+1}. 

+from .chat import get_llm

 class LeannChat:
-    """
-    The chat is responsible for the conversation with the LLM.
-    It will use the searcher to get the results and then use the LLM to generate the response.
-    """
-    def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
-        if backend_name is None:
-            leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
-            if not leann_meta_path.exists():
-                raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}.")
-            with open(leann_meta_path, 'r', encoding='utf-8') as f:
-                meta_data = json.load(f)
-            backend_name = meta_data['backend_name']
-
+    def __init__(self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs):
         self.searcher = LeannSearcher(index_path, **kwargs)
-        self.llm_model = llm_model
-
-    def ask(self, question: str, top_k=5, **kwargs):
-        """
-        Additional keyword arguments (kwargs) for advanced search customization. Example usage:
-        chat.ask(
-            "What is ANN?",
-            top_k=10,
-            complexity=64,
-            beam_width=8,
-            USE_DEFERRED_FETCH=True,
-            skip_search_reorder=True,
-            recompute_beighbor_embeddings=True,
-            dedup_node_dis=True,
-            prune_ratio=0.1,
-            batch_recompute=True,
-            global_pruning=True
-        )
-
-        Supported kwargs:
-        - complexity (int): Search complexity parameter (default: 32)
-        - beam_width (int): Beam width for search (default: 4)
-        - USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
-        - skip_search_reorder (bool): Skip search reorder step (default: False)
-        - recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
-        - dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
-        - prune_ratio (float): Pruning ratio for search (default: 0.0)
-        - batch_recompute (bool): Enable batch recomputation (default: False)
-        - global_pruning (bool): Enable global pruning (default: False)
-        """
+        self.llm = get_llm(llm_config)
+
+    def ask(self, question: str, top_k=5, **kwargs):
         results = self.searcher.search(question, top_k=top_k, **kwargs)
         context = "\n\n".join([r.text for r in results])
-
         prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
-
-        print(f"DEBUG: Calling LLM with prompt: {prompt}...")
-        try:
-            client = _get_openai_client()
-            response = client.chat.completions.create(
-                model=self.llm_model,
-                messages=[
-                    {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
-                    {"role": "user", "content": prompt}
-                ]
-            )
-            return response.choices[0].message.content
-        except Exception as e:
-            print(f"ERROR: Failed to call OpenAI API: {e}")
-            return f"Error: Could not get a response from the LLM. {e}"
-
-    def start_interactive(self):
-        print("\nLeann Chat started (type 'quit' to exit)")
-        while True:
-            try:
-                user_input = input("You: ").strip()
-                if user_input.lower() in ['quit', 'exit']:
-                    break
-                if not user_input:
-                    continue
-                response = self.ask(user_input)
-                print(f"Leann: {response}")
-            except (KeyboardInterrupt, EOFError):
-                print("\nGoodbye!")
-                break
+        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
\ No newline at end of file
{e}" - - def start_interactive(self): - print("\nLeann Chat started (type 'quit' to exit)") - while True: - try: - user_input = input("You: ").strip() - if user_input.lower() in ['quit', 'exit']: - break - if not user_input: - continue - response = self.ask(user_input) - print(f"Leann: {response}") - except (KeyboardInterrupt, EOFError): - print("\nGoodbye!") - break + return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {})) \ No newline at end of file diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index ef2fd4d..7205d87 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -73,15 +73,17 @@ class EmbeddingServerManager: self.server_process = subprocess.Popen( command, cwd=project_root, - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring text=True, - encoding='utf-8' + encoding='utf-8', + bufsize=1, # Line buffered + universal_newlines=True ) self.server_port = port print(f"INFO: Server process started with PID: {self.server_process.pid}") - max_wait, wait_interval = 30, 0.5 + max_wait, wait_interval = 120, 0.5 for _ in range(int(max_wait / wait_interval)): if _check_port(port): print(f"โœ… Embedding server is up and ready for this session.") @@ -90,7 +92,7 @@ class EmbeddingServerManager: return True if self.server_process.poll() is not None: print("โŒ ERROR: Server process terminated unexpectedly during startup.") - self._log_monitor() + self._print_recent_output() return False time.sleep(wait_interval) @@ -102,19 +104,32 @@ class EmbeddingServerManager: print(f"โŒ ERROR: Failed to start embedding server process: {e}") return False + def _print_recent_output(self): + """Print any recent output from the server process.""" + if not self.server_process or not self.server_process.stdout: + return + try: + # Read any available output + import select + import sys + if select.select([self.server_process.stdout], [], [], 0)[0]: + output = self.server_process.stdout.read() + if output: + print(f"[{self.backend_module_name} OUTPUT]: {output}") + except Exception as e: + print(f"Error reading server output: {e}") + def _log_monitor(self): """Monitors and prints the server's stdout and stderr.""" if not self.server_process: return try: if self.server_process.stdout: - for line in iter(self.server_process.stdout.readline, ''): - print(f"[{self.backend_module_name} LOG]: {line.strip()}") - self.server_process.stdout.close() - if self.server_process.stderr: - for line in iter(self.server_process.stderr.readline, ''): - print(f"[{self.backend_module_name} ERROR]: {line.strip()}") - self.server_process.stderr.close() + while True: + line = self.server_process.stdout.readline() + if not line: + break + print(f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True) except Exception as e: print(f"Log monitor error: {e}")