feat: dataset for evaluation
@@ -12,8 +12,6 @@ from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle

# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
@@ -21,50 +19,45 @@ sys.path.insert(0, str(project_root))

from leann.api import LeannSearcher

# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
def download_data_if_needed(data_root: Path):
    """Checks if the data directory exists, and if not, downloads it from the HF Hub."""
    if not data_root.exists():
        print(f"Data directory '{data_root}' not found.")
        print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
        try:
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="LEANN-RAG/leann-rag-evaluation-data",
                repo_type="dataset",
                local_dir=data_root,
                local_dir_use_symlinks=False,  # Recommended for Windows compatibility and a simpler structure
            )
            print("Data download complete!")
        except ImportError:
            print("Error: huggingface_hub is not installed. Please install it to download the data:")
            print('pip install -e ".[dev]"')
            sys.exit(1)
        except Exception as e:
            print(f"An error occurred during data download: {e}")
            sys.exit(1)
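
# Sketch: a minimal layout check, assuming the downloaded dataset mirrors the
# directories main() expects below ("queries/" and "ground_truth/<dataset>/").
def _check_eval_layout(data_root: Path) -> None:
    for rel in ("queries/nq_open.jsonl", "ground_truth/rpj_wiki/flat_results_nq_k3.json"):
        path = data_root / rel
        print(f"{path}: {'OK' if path.exists() else 'MISSING'}")
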
from leann.api import LeannSearcher

# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
    "rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
    "dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json",
}

# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
    "rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
    "dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl",
}

# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
    """A simplified version of the old LazyPassages class to fetch golden results by ID."""
    def __init__(self, passages_glob: str):
        self.jsonl_paths = sorted(glob.glob(passages_glob))
        self.offsets = {}
        self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
        print("Building offset map for original passages...")
        for i, shard_path_str in enumerate(self.jsonl_paths):
            old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
            if not old_idx_path.exists():
                continue
            with open(old_idx_path, 'rb') as f:
                shard_offsets = pickle.load(f)
            for pid, offset in shard_offsets.items():
                self.offsets[str(pid)] = (i, offset)
        print("Offset map for original passages is ready.")

    def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
        pid = str(pid)
        if pid not in self.offsets:
            raise ValueError(f"Passage ID {pid} not found in offsets")
        file_idx, offset = self.offsets[pid]
        fp = self.fps[file_idx]
        fp.seek(offset)
        return json.loads(fp.readline())

    def __del__(self):
        for fp in self.fps:
            fp.close()
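
# Usage sketch, assuming each shard record carries a 'text' field (as consumed
# in main() below); the passage ID here is a placeholder.
def _demo_old_loader() -> None:
    loader = OldPassageLoader(OLD_PASSAGES_GLOBS["dpr"])
    passage = loader.get_passage_by_id("42")  # placeholder ID
    print(passage["text"][:100])
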
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
    """
    Retrieves the text for golden passage IDs directly from the LeannSearcher's
    passage manager.
    """
    golden_texts = set()
    for gid in golden_ids:
        try:
            # PassageManager uses string IDs
            passage_data = searcher.passage_manager.get_passage(str(gid))
            golden_texts.add(passage_data['text'])
        except KeyError:
            print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
    return golden_texts
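
# Sketch: cross-check that the index's passage manager and the legacy shard
# loader return the same golden texts (useful when validating a new index).
def _cross_check_golden(searcher: LeannSearcher, loader: OldPassageLoader,
                        golden_ids: List[int]) -> bool:
    via_index = get_golden_texts(searcher, golden_ids)
    via_shards = {loader.get_passage_by_id(str(gid))["text"] for gid in golden_ids}
    return via_index == via_shards
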

def load_queries(file_path: Path) -> List[str]:
    queries = []
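    # (Remainder of the body lies outside this hunk; a plausible completion,
    # assuming one JSON object per line with a "question" field, would be:)
    #     with open(file_path, "r", encoding="utf-8") as f:
    #         for line in f:
    #             queries.append(json.loads(line)["question"])
    #     return queries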
@@ -82,35 +75,40 @@ def main():
    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
    args = parser.parse_args()

    print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")

    # Detect dataset type from index path
    # --- Path Configuration ---
    # Assumes a project structure where the script is in 'examples/'
    # and data is in 'data/' at the project root.
    project_root = Path(__file__).resolve().parent.parent
    data_root = project_root / "data"

    # Automatically download data if it doesn't exist
    download_data_if_needed(data_root)

    # Detect dataset type from index path to select the correct ground truth
    index_path_str = str(args.index_path)
    if "rpj_wiki" in index_path_str:
        dataset_type = "rpj_wiki"
    elif "dpr" in index_path_str:
        dataset_type = "dpr"
    else:
        print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
        dataset_type = "rpj_wiki"

        # Fallback: try to infer from the index directory name
        dataset_type = Path(args.index_path).name
        print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")

    queries_file = data_root / "queries" / "nq_open.jsonl"
    golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"

    print(f"INFO: Detected dataset type: {dataset_type}")
    print(f"INFO: Using queries file: {queries_file}")
    print(f"INFO: Using ground truth file: {golden_results_file}")

    try:
        searcher = LeannSearcher(args.index_path)
        queries = load_queries(NQ_QUERIES_FILE)

        golden_results_file = GROUND_TRUTH_FILES[dataset_type]
        old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]

        print(f"INFO: Using ground truth file: {golden_results_file}")
        print(f"INFO: Using old passages glob: {old_passages_glob}")
        queries = load_queries(queries_file)

        with open(golden_results_file, 'r') as f:
            golden_results_data = json.load(f)

        old_passage_loader = OldPassageLoader(old_passages_glob)

        num_eval_queries = min(args.num_queries, len(queries))
        queries = queries[:num_eval_queries]

@@ -125,8 +123,10 @@ def main():

            # Correct Recall Calculation: Based on TEXT content
            new_texts = {result.text for result in new_results}

            # Get golden texts directly from the searcher's passage manager
            golden_ids = golden_results_data["indices"][i][:args.top_k]
            golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
            golden_texts = get_golden_texts(searcher, golden_ids)

            overlap = len(new_texts & golden_texts)
            recall = overlap / len(golden_texts) if golden_texts else 0

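# Toy walk-through of the text-overlap recall computed above (values made up):
#   new_texts    = {"A", "B", "C"}   # top-k texts returned by the HNSW index
#   golden_texts = {"A", "B", "D"}   # top-k texts from the exact (flat) search
#   overlap = 2, so recall = 2 / 3 ≈ 0.667 for this query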