fix: reproducible dpr on mac
data/.gitattributes (vendored, new file, 82 additions)
@@ -0,0 +1,82 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.lz4 filter=lfs diff=lfs merge=lfs -text
+*.mds filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Audio files - uncompressed
+*.pcm filter=lfs diff=lfs merge=lfs -text
+*.sam filter=lfs diff=lfs merge=lfs -text
+*.raw filter=lfs diff=lfs merge=lfs -text
+# Audio files - compressed
+*.aac filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.ogg filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+# Image files - uncompressed
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.tiff filter=lfs diff=lfs merge=lfs -text
+# Image files - compressed
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
+# Video files - compressed
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.webm filter=lfs diff=lfs merge=lfs -text
+ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
+indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
+indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
+indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
+indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
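The rules above are the standard Hugging Face Hub LFS patterns, followed by the repo-specific ground-truth and index artifacts. As a rough illustration of what they select, here is a small Python sketch. It is illustrative only: git resolves these attributes itself (see `git check-attr`), fnmatch only approximates gitattributes glob semantics, and the pattern list samples a handful of the 82 rules.

# Rough, illustrative approximation of the .gitattributes rules above.
# The `saved_model/**/*` pattern is special-cased as a prefix match, since
# fnmatch has no `**` semantics.
from fnmatch import fnmatch
from pathlib import PurePosixPath

LFS_PATTERNS = [
    "*.npy",
    "*.parquet",
    "*.safetensors",
    "saved_model/**/*",
    "indices/dpr/dpr_diskann_disk.index",
]

def routed_to_lfs(path: str) -> bool:
    """Return True if `path` would match one of the sampled LFS patterns."""
    basename = PurePosixPath(path).name
    for pattern in LFS_PATTERNS:
        if "/" not in pattern:
            # A slash-free gitattributes pattern matches the basename at any depth.
            if fnmatch(basename, pattern):
                return True
        elif "**" in pattern:
            # Crude stand-in for `**`: treat everything under the prefix as a match.
            if path.startswith(pattern.split("**", 1)[0]):
                return True
        elif fnmatch(path, pattern):
            return True
    return False

print(routed_to_lfs("indices/dpr/embeddings.npy"))         # True (*.npy)
print(routed_to_lfs("indices/dpr/dpr_diskann_disk.index")) # True (exact path rule)
print(routed_to_lfs("README.md"))                          # False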
@@ -11,36 +11,37 @@ import time
 from pathlib import Path
 import sys
 import numpy as np
-from typing import List, Dict, Any
+from typing import List

-# Add project root to path to allow importing from leann
-project_root = Path(__file__).resolve().parent.parent
-sys.path.insert(0, str(project_root))

 from leann.api import LeannSearcher


 def download_data_if_needed(data_root: Path):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
     if not data_root.exists():
         print(f"Data directory '{data_root}' not found.")
-        print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
+        print(
+            "Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
+        )
         try:
             from huggingface_hub import snapshot_download

             snapshot_download(
                 repo_id="LEANN-RAG/leann-rag-evaluation-data",
                 repo_type="dataset",
                 local_dir=data_root,
-                local_dir_use_symlinks=False  # Recommended for Windows compatibility and simpler structure
+                local_dir_use_symlinks=False,  # Recommended for Windows compatibility and simpler structure
             )
             print("Data download complete!")
         except ImportError:
-            print("Error: huggingface_hub is not installed. Please install it to download the data:")
-            print("pip install -e ".[dev]"")
+            print(
+                "Error: huggingface_hub is not installed. Please install it to download the data:"
+            )
+            print("uv pip install -e '.[dev]'")
             sys.exit(1)
         except Exception as e:
             print(f"An error occurred during data download: {e}")
             sys.exit(1)
-from leann.api import LeannSearcher


 # --- Helper Function to get Golden Passages ---
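One tentative caveat on the hunk above: newer huggingface_hub releases deprecate the local_dir_use_symlinks argument (recent versions write real files into local_dir by default and ignore the flag), so the trailing-comma fix may outlive the flag itself.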
@@ -54,25 +55,39 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
         try:
             # PassageManager uses string IDs
             passage_data = searcher.passage_manager.get_passage(str(gid))
-            golden_texts.add(passage_data['text'])
+            golden_texts.add(passage_data["text"])
         except KeyError:
-            print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
+            print(
+                f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
+            )
     return golden_texts


 def load_queries(file_path: Path) -> List[str]:
     queries = []
-    with open(file_path, 'r', encoding='utf-8') as f:
+    with open(file_path, "r", encoding="utf-8") as f:
         for line in f:
             data = json.loads(line)
-            queries.append(data['query'])
+            queries.append(data["query"])
     return queries


 def main():
-    parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
-    parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
-    parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
-    parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
-    parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
+    parser = argparse.ArgumentParser(
+        description="Run recall evaluation on a LEANN index."
+    )
+    parser.add_argument(
+        "index_path", type=str, help="Path to the LEANN index to evaluate."
+    )
+    parser.add_argument(
+        "--num-queries", type=int, default=10, help="Number of queries to evaluate."
+    )
+    parser.add_argument(
+        "--top-k", type=int, default=3, help="The 'k' value for recall@k."
+    )
+    parser.add_argument(
+        "--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
+    )
     args = parser.parse_args()

     # --- Path Configuration ---
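Taken together, the arguments above imply an invocation along the lines of `python run_evaluation.py data/indices/dpr --num-queries 10 --top-k 3 --ef-search 120`, where the script name and index path are hypothetical placeholders; the diff does not show the file's actual location.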
@@ -93,10 +108,14 @@ def main():
     else:
         # Fallback: try to infer from the index directory name
         dataset_type = Path(args.index_path).name
-        print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
+        print(
+            f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
+        )

     queries_file = data_root / "queries" / "nq_open.jsonl"
-    golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
+    golden_results_file = (
+        data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
+    )

     print(f"INFO: Detected dataset type: {dataset_type}")
     print(f"INFO: Using queries file: {queries_file}")
@@ -106,7 +125,7 @@ def main():
         searcher = LeannSearcher(args.index_path)
         queries = load_queries(queries_file)

-        with open(golden_results_file, 'r') as f:
+        with open(golden_results_file, "r") as f:
             golden_results_data = json.load(f)

         num_eval_queries = min(args.num_queries, len(queries))
@@ -118,14 +137,16 @@ def main():

         for i in range(num_eval_queries):
             start_time = time.time()
-            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
+            new_results = searcher.search(
+                queries[i], top_k=args.top_k, ef=args.ef_search
+            )
             search_times.append(time.time() - start_time)

             # Correct Recall Calculation: Based on TEXT content
             new_texts = {result.text for result in new_results}

             # Get golden texts directly from the searcher's passage manager
-            golden_ids = golden_results_data["indices"][i][:args.top_k]
+            golden_ids = golden_results_data["indices"][i][: args.top_k]
             golden_texts = get_golden_texts(searcher, golden_ids)

             overlap = len(new_texts & golden_texts)
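For readers skimming the loop above, the search API surface this script relies on is small. A minimal standalone sketch, assuming only what is visible in this diff (LeannSearcher takes an index path, search() accepts top_k and ef, and results carry a .text attribute):

# Hypothetical standalone use of the searcher exercised in the loop above.
from leann.api import LeannSearcher

# The index path is a placeholder; see the .gitattributes entries for the layout.
searcher = LeannSearcher("data/indices/dpr")
results = searcher.search("who wrote the declaration of independence", top_k=3, ef=120)
for rank, result in enumerate(results, start=1):
    print(rank, result.text[:80])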
@@ -139,19 +160,21 @@ def main():
             print(f"Overlap: {overlap}")
             print(f"Recall: {recall}")
             print(f"Search Time: {search_times[-1]:.4f}s")
-            print(f"--------------------------------")
+            print("--------------------------------")

         avg_recall = np.mean(recall_scores) if recall_scores else 0
         avg_time = np.mean(search_times) if search_times else 0

-        print(f"\n🎉 --- Evaluation Complete ---")
+        print("\n🎉 --- Evaluation Complete ---")
         print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
         print(f"Avg. Search Time: {avg_time:.4f}s")

     except Exception as e:
         print(f"\n❌ An error occurred during evaluation: {e}")
         import traceback

         traceback.print_exc()


 if __name__ == "__main__":
     main()
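The recall arithmetic itself never appears inside the hunks shown, so for completeness here is a minimal sketch of the text-based recall@k the loop computes, assuming recall is defined as overlap divided by k:

def recall_at_k(retrieved_texts: set, golden_texts: set, k: int) -> float:
    # Text-based overlap, mirroring `overlap = len(new_texts & golden_texts)` above.
    return len(retrieved_texts & golden_texts) / k

retrieved = {"passage a", "passage b", "passage c"}
golden = {"passage a", "passage c", "passage d"}
print(recall_at_k(retrieved, golden, k=3))  # 0.667: 2 of the 3 golden passages retrieved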
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 2dcf156553...af2a26481e