#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.

It compares results by fetching the text content for both the new search results
and the golden-standard results, making the comparison robust to ID changes.
"""
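
# Example invocation (the script name, index path, and flag values here are
# illustrative only; the flags themselves are defined in main() below):
#   python run_recall_evaluation.py /path/to/index.leann --num-queries 100 --top-k 3 --ef-search 120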

import argparse
import glob
import json
import pickle
import sys
import time
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from leann.api import LeannSearcher

# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")

# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
    "rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
    "dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json",
}

# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
    "rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
    "dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl",
}
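
# Note: the ground-truth JSON files above are assumed to contain an "indices"
# field holding one list of golden passage IDs per query; that is the only field
# main() reads from them.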


# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
    """A simplified version of the old LazyPassages class to fetch golden results by ID."""

    def __init__(self, passages_glob: str):
        self.jsonl_paths = sorted(glob.glob(passages_glob))
        self.offsets = {}
        self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
        print("Building offset map for original passages...")
        for i, shard_path_str in enumerate(self.jsonl_paths):
            # Each .jsonl shard is expected to have a sibling .idx file holding a
            # pickled {passage_id: byte_offset} map; shards without one are skipped.
            old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
            if not old_idx_path.exists():
                continue
            with open(old_idx_path, "rb") as f:
                shard_offsets = pickle.load(f)
            for pid, offset in shard_offsets.items():
                self.offsets[str(pid)] = (i, offset)
        print("Offset map for original passages is ready.")

    def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
        pid = str(pid)
        if pid not in self.offsets:
            raise ValueError(f"Passage ID {pid} not found in offsets")
        file_idx, offset = self.offsets[pid]
        fp = self.fps[file_idx]
        fp.seek(offset)
        return json.loads(fp.readline())

    def __del__(self):
        for fp in self.fps:
            fp.close()
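

# Usage sketch for OldPassageLoader (the glob and passage ID are illustrative;
# the returned dict is assumed to carry at least a "text" field, which is how
# main() uses it):
#   loader = OldPassageLoader("/path/to/raw_passages-*-of-8.pkl.jsonl")
#   passage_text = loader.get_passage_by_id("12345")["text"]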


def load_queries(file_path: Path) -> List[str]:
    """Load the 'query' field from every line of a JSONL file."""
    queries = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            queries.append(data["query"])
    return queries
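

# NQ_QUERIES_FILE is read as JSONL; a minimal line looks like the following
# (the question text is illustrative, and any extra fields on a line are ignored):
#   {"query": "who wrote the declaration of independence"}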


def main():
    parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
    parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
    parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
    parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
    args = parser.parse_args()

    print(f"--- Recall Evaluation (efSearch={args.ef_search}) ---")

    # Detect dataset type from the index path
    index_path_str = str(args.index_path)
    if "rpj_wiki" in index_path_str:
        dataset_type = "rpj_wiki"
    elif "dpr" in index_path_str:
        dataset_type = "dpr"
    else:
        print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
        dataset_type = "rpj_wiki"

    print(f"INFO: Detected dataset type: {dataset_type}")

    try:
        searcher = LeannSearcher(args.index_path)
        queries = load_queries(NQ_QUERIES_FILE)

        golden_results_file = GROUND_TRUTH_FILES[dataset_type]
        old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]

        print(f"INFO: Using ground truth file: {golden_results_file}")
        print(f"INFO: Using old passages glob: {old_passages_glob}")

        with open(golden_results_file, "r") as f:
            golden_results_data = json.load(f)

        old_passage_loader = OldPassageLoader(old_passages_glob)

        num_eval_queries = min(args.num_queries, len(queries))
        queries = queries[:num_eval_queries]

        print(f"\nRunning evaluation on {num_eval_queries} queries...")
        recall_scores = []
        search_times = []

        for i in range(num_eval_queries):
            start_time = time.time()
            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
            search_times.append(time.time() - start_time)

            # Recall is computed on TEXT content (not IDs), so the comparison is
            # robust to passage ID changes between the old and new indexes.
            new_texts = {result.text for result in new_results}
            golden_ids = golden_results_data["indices"][i][:args.top_k]
            golden_texts = {old_passage_loader.get_passage_by_id(str(gid))["text"] for gid in golden_ids}

            overlap = len(new_texts & golden_texts)
            recall = overlap / len(golden_texts) if golden_texts else 0
            recall_scores.append(recall)

            print(f"\n--- Query {i + 1} Results ---")
            print(f"Query: {queries[i]}")
            print(f"New Results: {new_texts}")
            print(f"Golden Results: {golden_texts}")
            print(f"Overlap: {overlap}")
            print(f"Recall: {recall}")
            print(f"Search Time: {search_times[-1]:.4f}s")
            print("--------------------------------")

        avg_recall = np.mean(recall_scores) if recall_scores else 0
        avg_time = np.mean(search_times) if search_times else 0

        print("\n🎉 --- Evaluation Complete ---")
        print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
        print(f"Avg. Search Time: {avg_time:.4f}s")

    except Exception as e:
        print(f"\n❌ An error occurred during evaluation: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()