fix: reproducible dpr on mac

This commit is contained in:
Andy Lee
2025-07-12 18:13:22 -07:00
parent ecab43e307
commit 71ef4b7d4c
3 changed files with 138 additions and 33 deletions

View File

@@ -11,36 +11,37 @@ import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from typing import List
from leann.api import LeannSearcher
def download_data_if_needed(data_root: Path):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try:
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False # Recommended for Windows compatibility and simpler structure
local_dir_use_symlinks=False, # Recommended for Windows compatibility and simpler structure
)
print("Data download complete!")
except ImportError:
print("Error: huggingface_hub is not installed. Please install it to download the data:")
print("pip install -e ".[dev]"")
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
from leann.api import LeannSearcher
# --- Helper Function to get Golden Passages ---
@@ -54,29 +55,43 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data['text'])
golden_texts.add(passage_data["text"])
except KeyError:
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data['query'])
queries.append(data["query"])
return queries
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path", type=str, help="Path to the LEANN index to evaluate."
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'examples/'
# Assumes a project structure where the script is in 'examples/'
# and data is in 'data/' at the project root.
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
@@ -93,10 +108,14 @@ def main():
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
@@ -105,27 +124,29 @@ def main():
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file, 'r') as f:
with open(golden_results_file, "r") as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
new_results = searcher.search(
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][:args.top_k]
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
@@ -139,19 +160,21 @@ def main():
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print(f"--------------------------------")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print(f"\n🎉 --- Evaluation Complete ---")
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
main()