From 71ef4b7d4cd95a74ea176baa1d681196328fbeec Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Sat, 12 Jul 2025 18:13:22 -0700
Subject: [PATCH] fix: reproducible dpr on mac

---
 data/.gitattributes                           | 82 +++++++++++++++++
 examples/run_evaluation.py                    | 87 ++++++++++++-------
 .../leann-backend-diskann/third_party/DiskANN |  2 +-
 3 files changed, 138 insertions(+), 33 deletions(-)
 create mode 100644 data/.gitattributes

diff --git a/data/.gitattributes b/data/.gitattributes
new file mode 100644
index 0000000..4fb7c03
--- /dev/null
+++ b/data/.gitattributes
@@ -0,0 +1,82 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.lz4 filter=lfs diff=lfs merge=lfs -text
+*.mds filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Audio files - uncompressed
+*.pcm filter=lfs diff=lfs merge=lfs -text
+*.sam filter=lfs diff=lfs merge=lfs -text
+*.raw filter=lfs diff=lfs merge=lfs -text
+# Audio files - compressed
+*.aac filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.ogg filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+# Image files - uncompressed
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.tiff filter=lfs diff=lfs merge=lfs -text
+# Image files - compressed
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
+# Video files - compressed
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.webm filter=lfs diff=lfs merge=lfs -text
+ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
+indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
+indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
+indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
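Note: the attributes above only pay off when git-lfs is installed at clone time; otherwise the checkout contains small text pointer stubs where the large index files should be, and the evaluation script fails much later with opaque errors. A minimal preflight sketch for catching that, assuming the data/ checkout layout above (the pointer header itself is fixed by the Git LFS spec):

    from pathlib import Path

    LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"

    def is_lfs_pointer(path: Path) -> bool:
        """Detects an un-fetched Git LFS pointer stub left in place of real data."""
        try:
            with open(path, "rb") as f:
                return f.read(len(LFS_HEADER)) == LFS_HEADER
        except OSError:
            return False

    # Example: make sure the DPR index was materialized before opening it.
    index_file = Path("data/indices/dpr/dpr_diskann_disk.index")
    if is_lfs_pointer(index_file):
        raise SystemExit(f"{index_file} is an LFS pointer; run 'git lfs pull' first.")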
diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py
index 21421d0..f80d7a6 100644
--- a/examples/run_evaluation.py
+++ b/examples/run_evaluation.py
@@ -11,36 +11,37 @@ import time
 from pathlib import Path
 import sys
 import numpy as np
-from typing import List, Dict, Any
-
-# Add project root to path to allow importing from leann
-project_root = Path(__file__).resolve().parent.parent
-sys.path.insert(0, str(project_root))
+from typing import List
 from leann.api import LeannSearcher
 
+
 def download_data_if_needed(data_root: Path):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
     if not data_root.exists():
         print(f"Data directory '{data_root}' not found.")
-        print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
+        print(
+            "Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
+        )
         try:
             from huggingface_hub import snapshot_download
+
             snapshot_download(
                 repo_id="LEANN-RAG/leann-rag-evaluation-data",
                 repo_type="dataset",
                 local_dir=data_root,
-                local_dir_use_symlinks=False # Recommended for Windows compatibility and simpler structure
+                local_dir_use_symlinks=False,  # Recommended for Windows compatibility and simpler structure
             )
             print("Data download complete!")
         except ImportError:
-            print("Error: huggingface_hub is not installed. Please install it to download the data:")
-            print("pip install -e ".[dev]"")
+            print(
+                "Error: huggingface_hub is not installed. Please install it to download the data:"
+            )
+            print("uv pip install -e '.[dev]'")
             sys.exit(1)
         except Exception as e:
             print(f"An error occurred during data download: {e}")
             sys.exit(1)
 
-from leann.api import LeannSearcher
 
 # --- Helper Function to get Golden Passages ---
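The download-on-demand block above is also usable on its own: prefetching the dataset from a Python shell avoids a long pause on the first evaluation run. A sketch mirroring the call in download_data_if_needed (only the relative target directory is an assumption here):

    from huggingface_hub import snapshot_download

    # Prefetch the evaluation data into data/ so run_evaluation.py finds it on startup.
    snapshot_download(
        repo_id="LEANN-RAG/leann-rag-evaluation-data",
        repo_type="dataset",
        local_dir="data",
        local_dir_use_symlinks=False,
    )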
@@ -54,29 +55,43 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
         try:
             # PassageManager uses string IDs
             passage_data = searcher.passage_manager.get_passage(str(gid))
-            golden_texts.add(passage_data['text'])
+            golden_texts.add(passage_data["text"])
         except KeyError:
-            print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
+            print(
+                f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
+            )
     return golden_texts
 
+
 def load_queries(file_path: Path) -> List[str]:
     queries = []
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         for line in f:
             data = json.loads(line)
-            queries.append(data['query'])
+            queries.append(data["query"])
     return queries
 
+
 def main():
-    parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
-    parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
-    parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
-    parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
-    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
+    parser = argparse.ArgumentParser(
+        description="Run recall evaluation on a LEANN index."
+    )
+    parser.add_argument(
+        "index_path", type=str, help="Path to the LEANN index to evaluate."
+    )
+    parser.add_argument(
+        "--num-queries", type=int, default=10, help="Number of queries to evaluate."
+    )
+    parser.add_argument(
+        "--top-k", type=int, default=3, help="The 'k' value for recall@k."
+    )
+    parser.add_argument(
+        "--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
+    )
     args = parser.parse_args()
 
     # --- Path Configuration ---
-    # Assumes a project structure where the script is in 'examples/' 
+    # Assumes a project structure where the script is in 'examples/'
     # and data is in 'data/' at the project root.
     project_root = Path(__file__).resolve().parent.parent
     data_root = project_root / "data"
@@ -93,10 +108,14 @@ def main():
     else:
         # Fallback: try to infer from the index directory name
        dataset_type = Path(args.index_path).name
-        print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
+        print(
+            f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
+        )
 
     queries_file = data_root / "queries" / "nq_open.jsonl"
-    golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
+    golden_results_file = (
+        data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
+    )
 
     print(f"INFO: Detected dataset type: {dataset_type}")
     print(f"INFO: Using queries file: {queries_file}")
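Since both input files must exist before the searcher is constructed, a short preflight along these lines gives a clearer failure than a downstream exception; a sketch, with "dpr" standing in for whatever dataset type the script detected:

    from pathlib import Path

    data_root = Path("data")
    dataset_type = "dpr"  # example value; the script derives this from the index path

    expected = [
        data_root / "queries" / "nq_open.jsonl",
        data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json",
    ]
    missing = [p for p in expected if not p.exists()]
    if missing:
        raise SystemExit(f"Missing evaluation inputs: {missing}")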
@@ -105,27 +124,29 @@ def main():
     try:
         searcher = LeannSearcher(args.index_path)
         queries = load_queries(queries_file)
-        
-        with open(golden_results_file, 'r') as f:
+
+        with open(golden_results_file, "r") as f:
             golden_results_data = json.load(f)
-        
+
         num_eval_queries = min(args.num_queries, len(queries))
         queries = queries[:num_eval_queries]
-        
+
         print(f"\nRunning evaluation on {num_eval_queries} queries...")
         recall_scores = []
         search_times = []
 
         for i in range(num_eval_queries):
             start_time = time.time()
-            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
+            new_results = searcher.search(
+                queries[i], top_k=args.top_k, ef=args.ef_search
+            )
             search_times.append(time.time() - start_time)
 
             # Correct Recall Calculation: Based on TEXT content
             new_texts = {result.text for result in new_results}
-            
+
             # Get golden texts directly from the searcher's passage manager
-            golden_ids = golden_results_data["indices"][i][:args.top_k]
+            golden_ids = golden_results_data["indices"][i][: args.top_k]
             golden_texts = get_golden_texts(searcher, golden_ids)
 
             overlap = len(new_texts & golden_texts)
@@ -139,19 +160,21 @@ def main():
             print(f"Overlap: {overlap}")
             print(f"Recall: {recall}")
             print(f"Search Time: {search_times[-1]:.4f}s")
-            print(f"--------------------------------")
+            print("--------------------------------")
 
         avg_recall = np.mean(recall_scores) if recall_scores else 0
         avg_time = np.mean(search_times) if search_times else 0
 
-        print(f"\nšŸŽ‰ --- Evaluation Complete ---")
+        print("\nšŸŽ‰ --- Evaluation Complete ---")
         print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
         print(f"Avg. Search Time: {avg_time:.4f}s")
 
     except Exception as e:
         print(f"\nāŒ An error occurred during evaluation: {e}")
         import traceback
+
         traceback.print_exc()
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/packages/leann-backend-diskann/third_party/DiskANN b/packages/leann-backend-diskann/third_party/DiskANN
index 2dcf156..af2a264 160000
--- a/packages/leann-backend-diskann/third_party/DiskANN
+++ b/packages/leann-backend-diskann/third_party/DiskANN
@@ -1 +1 @@
-Subproject commit 2dcf156553050eeaf56e7b003f416fab70465429
+Subproject commit af2a26481e65232b57b82d96e68833cdee9f7635
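A closing note on the metric: the loop compares passage texts rather than raw integer IDs (the in-code comment calls this the "Correct Recall Calculation"), so recall@k reduces to a set overlap between the searcher's results and the flat-search baseline. A worked example; the final division by k is an assumption here, since the hunk that actually fills recall_scores is not part of this diff:

    top_k = 3
    new_texts = {"passage a", "passage b", "passage c"}     # returned by searcher.search
    golden_texts = {"passage a", "passage c", "passage d"}  # flat-search top-3, via get_golden_texts

    overlap = len(new_texts & golden_texts)  # 2 shared passages
    recall = overlap / top_k                 # 0.6667, contributes to Avg. Recall@3

With the LFS data in place, a typical invocation would presumably be `python examples/run_evaluation.py data/indices/dpr`, so that the dataset type inferred from the directory name is "dpr" and the defaults above (10 queries, top-k 3, efSearch 120) apply.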