Resolve submodule conflict - update to af2a264

This commit is contained in:
yichuan520030910320
2025-07-13 17:03:42 -07:00
10 changed files with 537 additions and 138 deletions

View File

@@ -11,122 +11,143 @@ import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle
# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from typing import List
from leann.api import LeannSearcher
# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
"dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
}
def download_data_if_needed(data_root: Path):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try:
from huggingface_hub import snapshot_download
# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
"dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
}
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False, # Recommended for Windows compatibility and simpler structure
)
print("Data download complete!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
"""A simplified version of the old LazyPassages class to fetch golden results by ID."""
def __init__(self, passages_glob: str):
self.jsonl_paths = sorted(glob.glob(passages_glob))
self.offsets = {}
self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
print("Building offset map for original passages...")
for i, shard_path_str in enumerate(self.jsonl_paths):
old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
if not old_idx_path.exists(): continue
with open(old_idx_path, 'rb') as f:
shard_offsets = pickle.load(f)
for pid, offset in shard_offsets.items():
self.offsets[str(pid)] = (i, offset)
print("Offset map for original passages is ready.")
def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
pid = str(pid)
if pid not in self.offsets:
raise ValueError(f"Passage ID {pid} not found in offsets")
file_idx, offset = self.offsets[pid]
fp = self.fps[file_idx]
fp.seek(offset)
return json.loads(fp.readline())
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts
def __del__(self):
for fp in self.fps:
fp.close()
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data['query'])
queries.append(data["query"])
return queries
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path", type=str, help="Path to the LEANN index to evaluate."
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
# Detect dataset type from index path
# --- Path Configuration ---
# Assumes a project structure where the script is in 'examples/'
# and data is in 'data/' at the project root.
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
# Automatically download data if it doesn't exist
download_data_if_needed(data_root)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
dataset_type = "rpj_wiki"
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(NQ_QUERIES_FILE)
golden_results_file = GROUND_TRUTH_FILES[dataset_type]
old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
print(f"INFO: Using ground truth file: {golden_results_file}")
print(f"INFO: Using old passages glob: {old_passages_glob}")
with open(golden_results_file, 'r') as f:
queries = load_queries(queries_file)
with open(golden_results_file, "r") as f:
golden_results_data = json.load(f)
old_passage_loader = OldPassageLoader(old_passages_glob)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
new_results = searcher.search(
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
golden_ids = golden_results_data["indices"][i][:args.top_k]
golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
@@ -139,19 +160,21 @@ def main():
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print(f"--------------------------------")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print(f"\n🎉 --- Evaluation Complete ---")
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
main()