feat: allow building the index from existing embeddings
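
Adds a 'build' mode to the recall-evaluation script. In build mode the
script downloads pre-computed embeddings from the
LEANN-RAG/leann-rag-evaluation-data dataset on the Hugging Face Hub
(unless --embeddings-file points at a local pickle), builds a LEANN
index from them via LeannBuilder.build_index_from_embeddings, and then
optionally runs the evaluation. The positional index_path argument
becomes optional and is auto-generated in build mode.

Intended invocations, as a sketch (the script filename is not visible
in this diff, so run_evaluation.py is a placeholder):

    # Build an HNSW index from auto-downloaded DPR embeddings
    python run_evaluation.py --mode build --backend hnsw

    # Evaluate an existing index (the previous default behavior)
    python run_evaluation.py data/indices/dpr/dpr_from_embeddings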
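For a local --embeddings-file, the file must be a pickle holding a
single (ids, embeddings) tuple, per the build_index_from_embeddings
docstring below. A minimal sketch of producing such a file; the id
format and float32 dtype are assumptions, and 768 matches the
facebook/contriever-msmarco dimension used in the diff:

    import pickle
    import numpy as np

    # Hypothetical toy data: 1000 passage ids plus their embeddings.
    ids = [str(i) for i in range(1000)]
    embeddings = np.random.rand(len(ids), 768).astype(np.float32)

    # One (ids, embeddings) tuple per file, mirroring the
    # embeddings/<dataset>/passages_00.pkl layout the downloader expects.
    with open("passages_00.pkl", "wb") as f:
        pickle.dump((ids, embeddings), f)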
@@ -13,10 +13,10 @@ import sys
 import numpy as np
 from typing import List

-from leann.api import LeannSearcher
+from leann.api import LeannSearcher, LeannBuilder


-def download_data_if_needed(data_root: Path):
+def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
     if not data_root.exists():
         print(f"Data directory '{data_root}' not found.")
@@ -26,13 +26,32 @@ def download_data_if_needed(data_root: Path):
         try:
             from huggingface_hub import snapshot_download

-            snapshot_download(
-                repo_id="LEANN-RAG/leann-rag-evaluation-data",
-                repo_type="dataset",
-                local_dir=data_root,
-                local_dir_use_symlinks=False,  # Recommended for Windows compatibility and simpler structure
-            )
-            print("Data download complete!")
+            if download_embeddings:
+                # Download everything including embeddings (large files)
+                snapshot_download(
+                    repo_id="LEANN-RAG/leann-rag-evaluation-data",
+                    repo_type="dataset",
+                    local_dir=data_root,
+                    local_dir_use_symlinks=False,
+                )
+                print("Data download complete (including embeddings)!")
+            else:
+                # Download only specific folders, excluding embeddings
+                allow_patterns = [
+                    "ground_truth/**",
+                    "indices/**",
+                    "queries/**",
+                    "*.md",
+                    "*.txt",
+                ]
+                snapshot_download(
+                    repo_id="LEANN-RAG/leann-rag-evaluation-data",
+                    repo_type="dataset",
+                    local_dir=data_root,
+                    local_dir_use_symlinks=False,
+                    allow_patterns=allow_patterns,
+                )
+                print("Data download complete (excluding embeddings)!")
         except ImportError:
             print(
                 "Error: huggingface_hub is not installed. Please install it to download the data:"
@@ -44,6 +63,43 @@ def download_data_if_needed(data_root: Path):
         sys.exit(1)


+def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
+    """Download embeddings files specifically."""
+    embeddings_dir = data_root / "embeddings"
+
+    if dataset_type:
+        # Check if the embeddings for this dataset already exist
+        target_file = embeddings_dir / dataset_type / "passages_00.pkl"
+        if target_file.exists():
+            print(f"Embeddings for {dataset_type} already exist")
+            return str(target_file)
+
+    print("Downloading embeddings from HuggingFace Hub...")
+    try:
+        from huggingface_hub import snapshot_download
+
+        # Download only the embeddings folder
+        snapshot_download(
+            repo_id="LEANN-RAG/leann-rag-evaluation-data",
+            repo_type="dataset",
+            local_dir=data_root,
+            local_dir_use_symlinks=False,
+            allow_patterns=["embeddings/**/*.pkl"],
+        )
+        print("Embeddings download complete!")
+
+        if dataset_type:
+            target_file = embeddings_dir / dataset_type / "passages_00.pkl"
+            if target_file.exists():
+                return str(target_file)
+
+        return str(embeddings_dir)
+
+    except Exception as e:
+        print(f"Error downloading embeddings: {e}")
+        sys.exit(1)
+
+
 # --- Helper Function to get Golden Passages ---
 def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
     """
@@ -72,12 +128,76 @@ def load_queries(file_path: Path) -> List[str]:
     return queries


+def build_index_from_embeddings(
+    embeddings_file: str, output_path: str, backend: str = "hnsw"
+):
+    """
+    Build a LEANN index from pre-computed embeddings.
+
+    Args:
+        embeddings_file: Path to a pickle file holding an (ids, embeddings) tuple
+        output_path: Path where the index will be saved
+        backend: Backend to use ("hnsw" or "diskann")
+    """
+    print(f"Building {backend} index from embeddings: {embeddings_file}")
+
+    # Create the builder with backend-appropriate parameters
+    if backend == "hnsw":
+        builder_kwargs = {
+            "M": 32,  # Graph degree
+            "efConstruction": 256,  # Construction complexity
+            "is_compact": True,  # Use compact storage
+            "is_recompute": True,  # Enable pruning for better recall
+        }
+    elif backend == "diskann":
+        builder_kwargs = {
+            "complexity": 64,
+            "graph_degree": 32,
+            "search_memory_maximum": 8.0,  # GB
+            "build_memory_maximum": 16.0,  # GB
+        }
+    else:
+        builder_kwargs = {}
+
+    builder = LeannBuilder(
+        backend_name=backend,
+        embedding_model="facebook/contriever-msmarco",  # Model used to create the embeddings
+        dimensions=768,  # Will be auto-detected from the embeddings
+        **builder_kwargs,
+    )
+
+    # Build the index from the precomputed embeddings
+    builder.build_index_from_embeddings(output_path, embeddings_file)
+    print(f"Index saved to: {output_path}")
+    return output_path
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Run recall evaluation on a LEANN index."
     )
     parser.add_argument(
-        "index_path", type=str, help="Path to the LEANN index to evaluate."
+        "index_path",
+        type=str,
+        nargs="?",
+        help="Path to the LEANN index to evaluate or build (optional).",
     )
+    parser.add_argument(
+        "--mode",
+        choices=["evaluate", "build"],
+        default="evaluate",
+        help="Mode: 'evaluate' an existing index or 'build' one from embeddings",
+    )
+    parser.add_argument(
+        "--embeddings-file",
+        type=str,
+        help="Path to an embeddings pickle file (optional in build mode)",
+    )
+    parser.add_argument(
+        "--backend",
+        choices=["hnsw", "diskann"],
+        default="hnsw",
+        help="Backend to use for building the index (default: hnsw)",
+    )
     parser.add_argument(
         "--num-queries", type=int, default=10, help="Number of queries to evaluate."
@@ -96,8 +216,90 @@ def main():
     project_root = Path(__file__).resolve().parent.parent
     data_root = project_root / "data"

-    # Automatically download data if it doesn't exist
-    download_data_if_needed(data_root)
+    # Download data based on mode
+    if args.mode == "build":
+        # For build mode, we need embeddings
+        download_data_if_needed(
+            data_root, download_embeddings=False
+        )  # Basic data first
+
+        # Auto-detect the dataset type and download embeddings
+        if args.embeddings_file:
+            embeddings_file = args.embeddings_file
+            # Try to detect the dataset type from the embeddings file path
+            if "rpj_wiki" in str(embeddings_file):
+                dataset_type = "rpj_wiki"
+            elif "dpr" in str(embeddings_file):
+                dataset_type = "dpr"
+            else:
+                dataset_type = "dpr"  # Default
+        else:
+            # Auto-detect from the index path if provided, otherwise default to DPR
+            if args.index_path:
+                index_path_str = str(args.index_path)
+                if "rpj_wiki" in index_path_str:
+                    dataset_type = "rpj_wiki"
+                elif "dpr" in index_path_str:
+                    dataset_type = "dpr"
+                else:
+                    dataset_type = "dpr"  # Default to DPR
+            else:
+                dataset_type = "dpr"  # Default to DPR
+
+            embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
+
+        # Auto-generate an index path if not provided
+        if not args.index_path:
+            indices_dir = data_root / "indices" / dataset_type
+            indices_dir.mkdir(parents=True, exist_ok=True)
+            args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
+            print(f"Auto-generated index path: {args.index_path}")
+
+        print(f"Building index from embeddings: {embeddings_file}")
+        built_index_path = build_index_from_embeddings(
+            embeddings_file, args.index_path, args.backend
+        )
+        print(f"Index built successfully: {built_index_path}")
+
+        # Ask whether the user wants to run the evaluation
+        eval_response = (
+            input("Run evaluation on the built index? (y/n): ").strip().lower()
+        )
+        if eval_response != "y":
+            print("Index building complete. Exiting.")
+            return
+    else:
+        # For evaluation mode, we don't need embeddings
+        download_data_if_needed(data_root, download_embeddings=False)
+
+    # Auto-detect the index path if not provided
+    if not args.index_path:
+        # Default to using the downloaded indices
+        indices_dir = data_root / "indices"
+
+        # Try common datasets in order of preference
+        for dataset in ["dpr", "rpj_wiki"]:
+            dataset_dir = indices_dir / dataset
+            if dataset_dir.exists():
+                # Look for index files
+                index_files = list(dataset_dir.glob("*.index")) + list(
+                    dataset_dir.glob("*_disk.index")
+                )
+                if index_files:
+                    args.index_path = str(
+                        index_files[0].with_suffix("")
+                    )  # Remove the .index extension
+                    print(f"Using index: {args.index_path}")
+                    break
+
+        if not args.index_path:
+            print(
+                "No indices found. The data download should have included pre-built indices."
+            )
+            print(
+                "Please check the data/indices/ directory or provide --index-path manually."
+            )
+            sys.exit(1)
+
     # Detect dataset type from index path to select the correct ground truth
     index_path_str = str(args.index_path)