fix: reproducible dpr on mac

Andy Lee
2025-07-12 18:13:22 -07:00
parent ecab43e307
commit 71ef4b7d4c
3 changed files with 138 additions and 33 deletions

data/.gitattributes

@@ -0,0 +1,82 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
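
These patterns route every large artifact under data/ (the archive and tensor formats at the top, plus the specific DPR and RPJ-Wiki index files at the bottom) through Git LFS, so a plain clone stays small and the heavy binaries are fetched as LFS objects. A minimal way to sanity-check and fetch the tracked files, assuming Git LFS is installed; the file path and include pattern below are illustrative:

# One-time setup: registers the LFS filter drivers in your git config
git lfs install

# Confirm a data file is routed through LFS by the rules above
git check-attr filter diff merge -- data/indices/dpr/dpr_diskann_disk.index

# Fetch only the DPR index objects rather than all LFS content
git lfs pull --include="data/indices/dpr/*"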


@@ -11,36 +11,37 @@ import time
 from pathlib import Path
 import sys
 import numpy as np
-from typing import List, Dict, Any
-
-# Add project root to path to allow importing from leann
-project_root = Path(__file__).resolve().parent.parent
-sys.path.insert(0, str(project_root))
+from typing import List
 
 from leann.api import LeannSearcher
 
 
 def download_data_if_needed(data_root: Path):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
     if not data_root.exists():
         print(f"Data directory '{data_root}' not found.")
-        print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
+        print(
+            "Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
+        )
         try:
             from huggingface_hub import snapshot_download
+
             snapshot_download(
                 repo_id="LEANN-RAG/leann-rag-evaluation-data",
                 repo_type="dataset",
                 local_dir=data_root,
-                local_dir_use_symlinks=False  # Recommended for Windows compatibility and simpler structure
+                local_dir_use_symlinks=False,  # Recommended for Windows compatibility and simpler structure
             )
             print("Data download complete!")
         except ImportError:
-            print("Error: huggingface_hub is not installed. Please install it to download the data:")
-            print('pip install -e ".[dev]"')
+            print(
+                "Error: huggingface_hub is not installed. Please install it to download the data:"
+            )
+            print("uv pip install -e '.[dev]'")
             sys.exit(1)
         except Exception as e:
             print(f"An error occurred during data download: {e}")
             sys.exit(1)
-
-from leann.api import LeannSearcher
 
 
 # --- Helper Function to get Golden Passages ---
@@ -54,29 +55,43 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
         try:
             # PassageManager uses string IDs
             passage_data = searcher.passage_manager.get_passage(str(gid))
-            golden_texts.add(passage_data['text'])
+            golden_texts.add(passage_data["text"])
         except KeyError:
-            print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
+            print(
+                f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
+            )
     return golden_texts
 
 
 def load_queries(file_path: Path) -> List[str]:
     queries = []
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         for line in f:
             data = json.loads(line)
-            queries.append(data['query'])
+            queries.append(data["query"])
     return queries
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
-    parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
-    parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
-    parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
-    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
+    parser = argparse.ArgumentParser(
+        description="Run recall evaluation on a LEANN index."
+    )
+    parser.add_argument(
+        "index_path", type=str, help="Path to the LEANN index to evaluate."
+    )
+    parser.add_argument(
+        "--num-queries", type=int, default=10, help="Number of queries to evaluate."
+    )
+    parser.add_argument(
+        "--top-k", type=int, default=3, help="The 'k' value for recall@k."
+    )
+    parser.add_argument(
+        "--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
+    )
     args = parser.parse_args()
 
     # --- Path Configuration ---
     # Assumes a project structure where the script is in 'examples/'
     # and data is in 'data/' at the project root.
     project_root = Path(__file__).resolve().parent.parent
     data_root = project_root / "data"
@@ -93,10 +108,14 @@ def main():
     else:
         # Fallback: try to infer from the index directory name
        dataset_type = Path(args.index_path).name
-        print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
+        print(
+            f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
+        )
 
     queries_file = data_root / "queries" / "nq_open.jsonl"
-    golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
+    golden_results_file = (
+        data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
+    )
 
     print(f"INFO: Detected dataset type: {dataset_type}")
     print(f"INFO: Using queries file: {queries_file}")
@@ -105,27 +124,29 @@ def main():
     try:
         searcher = LeannSearcher(args.index_path)
         queries = load_queries(queries_file)
-        with open(golden_results_file, 'r') as f:
+        with open(golden_results_file, "r") as f:
             golden_results_data = json.load(f)
 
         num_eval_queries = min(args.num_queries, len(queries))
         queries = queries[:num_eval_queries]
         print(f"\nRunning evaluation on {num_eval_queries} queries...")
 
         recall_scores = []
         search_times = []
 
         for i in range(num_eval_queries):
             start_time = time.time()
-            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
+            new_results = searcher.search(
+                queries[i], top_k=args.top_k, ef=args.ef_search
+            )
             search_times.append(time.time() - start_time)
 
             # Correct Recall Calculation: Based on TEXT content
             new_texts = {result.text for result in new_results}
 
             # Get golden texts directly from the searcher's passage manager
-            golden_ids = golden_results_data["indices"][i][:args.top_k]
+            golden_ids = golden_results_data["indices"][i][: args.top_k]
             golden_texts = get_golden_texts(searcher, golden_ids)
 
             overlap = len(new_texts & golden_texts)
@@ -139,19 +160,21 @@ def main():
             print(f"Overlap: {overlap}")
             print(f"Recall: {recall}")
             print(f"Search Time: {search_times[-1]:.4f}s")
-            print(f"--------------------------------")
+            print("--------------------------------")
 
         avg_recall = np.mean(recall_scores) if recall_scores else 0
         avg_time = np.mean(search_times) if search_times else 0
 
-        print(f"\n🎉 --- Evaluation Complete ---")
+        print("\n🎉 --- Evaluation Complete ---")
         print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
         print(f"Avg. Search Time: {avg_time:.4f}s")
 
     except Exception as e:
         print(f"\n❌ An error occurred during evaluation: {e}")
         import traceback
 
         traceback.print_exc()
 
 
 if __name__ == "__main__":
     main()
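
For reference, a hypothetical invocation of the reformatted evaluation script; its path and the index location are not shown in this view, so both are placeholders, and the flag values simply mirror the argparse defaults above:

# Placeholder paths; substitute the real script and index locations
python examples/run_evaluation.py data/indices/dpr --num-queries 10 --top-k 3 --ef-search 120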