Reproduce DocVQA results
@@ -223,17 +223,13 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
     model.eval()
 
     # Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings:
-    # 1. MTEB receives batch["text"] which may already include instruction/prompt
+    # 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text)
     # 2. Manually adds: query_prefix + text + query_augmentation_token * 10
     # 3. Calls processor.process_queries(batch) where batch is now a list of strings
     # 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10)
     #
-    # However, MTEB's approach results in duplicate addition (20 tokens total).
-    # Since we're already adding the prompt in search_queries, let's try:
-    # Option 1: Just call process_queries (let it handle all additions) - avoids duplicate
-    # Option 2: Manual add + process_texts (to avoid duplicate)
-    #
-    # Testing shows Option 1 works better - just call process_queries without manual addition
+    # This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
+    # We need to match this exactly to reproduce MTEB results
 
     all_embeds = []
     batch_size = 32  # Match MTEB's default batch_size
@@ -242,9 +238,15 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
     for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
         batch_queries = queries[i:i + batch_size]
 
-        # Just call process_queries - it will add query_prefix + text + 10 tokens
-        # This avoids duplicate addition that happens in MTEB's approach
-        inputs = processor.process_queries(batch_queries)
+        # Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
+        # Then process_queries will add them again (resulting in 20 augmentation tokens total)
+        batch = [
+            processor.query_prefix
+            + t
+            + processor.query_augmentation_token * 10
+            for t in batch_queries
+        ]
+        inputs = processor.process_queries(batch)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
         if model.device.type == "cuda":
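Note: the following is an illustrative sketch, not part of the commit, showing roughly what one query string looks like after the MTEB-matching preprocessing above. It assumes only what the diff's own comments state: that the ColPali/ColQwen processor exposes query_prefix and query_augmentation_token attributes and that process_queries prepends query_prefix and appends ten augmentation tokens as a suffix; the concrete string values depend on the processor.

# Illustrative sketch (assumption-laden, not authoritative): effective query text
# under the MTEB-matching scheme in the hunk above.
def mteb_style_query_text(processor, text: str) -> str:
    # Manual step done in _embed_queries (mirrors MTEB's ColPaliEngineWrapper):
    manually_augmented = (
        processor.query_prefix + text + processor.query_augmentation_token * 10
    )
    # process_queries then adds its own prefix and suffix, so query_prefix appears
    # twice and 20 augmentation tokens are appended in total, per the comments above.
    return (
        processor.query_prefix
        + manually_augmented
        + processor.query_augmentation_token * 10
    )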
@@ -1044,3 +1046,249 @@ class LeannMultiVector:
                     "image_path": meta.get("image_path", ""),
                 }
         return None
+
+
+class ViDoReBenchmarkEvaluator:
+    """
+    A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
+    This class encapsulates common functionality for building indexes, searching, and evaluating.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        use_fast_plaid: bool = False,
+        top_k: int = 100,
+        first_stage_k: int = 500,
+        k_values: list[int] = None,
+    ):
+        """
+        Initialize the evaluator.
+
+        Args:
+            model_name: Model name ("colqwen2" or "colpali")
+            use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
+            top_k: Top-k results to retrieve
+            first_stage_k: First stage k for LEANN search
+            k_values: List of k values for evaluation metrics
+        """
+        self.model_name = model_name
+        self.use_fast_plaid = use_fast_plaid
+        self.top_k = top_k
+        self.first_stage_k = first_stage_k
+        self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
+
+        # Load model once (can be reused across tasks)
+        self._model = None
+        self._processor = None
+        self._model_name_actual = None
+
+    def _load_model_if_needed(self):
+        """Lazy load the model."""
+        if self._model is None:
+            print(f"\nLoading model: {self.model_name}")
+            self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(self.model_name)
+            print(f"Model loaded: {self._model_name_actual}")
+
+    def build_index_from_corpus(
+        self,
+        corpus: dict[str, Image.Image],
+        index_path: str,
+        rebuild: bool = False,
+    ) -> tuple[Any, list[str]]:
+        """
+        Build index from corpus images.
+
+        Args:
+            corpus: dict mapping corpus_id to PIL Image
+            index_path: Path to save/load the index
+            rebuild: Whether to rebuild even if index exists
+
+        Returns:
+            tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
+        """
+        self._load_model_if_needed()
+
+        # Ensure consistent ordering
+        corpus_ids = sorted(corpus.keys())
+        images = [corpus[cid] for cid in corpus_ids]
+
+        if self.use_fast_plaid:
+            # Check if Fast-Plaid index exists
+            if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
+                print(f"Fast-Plaid index already exists at {index_path}")
+                return _load_fast_plaid_index_if_exists(index_path), corpus_ids
+
+            print(f"Building Fast-Plaid index at {index_path}...")
+            print("Embedding images...")
+            doc_vecs = _embed_images(self._model, self._processor, images)
+
+            fast_plaid_index, build_time = _build_fast_plaid_index(
+                index_path, doc_vecs, corpus_ids, images
+            )
+            print(f"Fast-Plaid index built in {build_time:.2f}s")
+            return fast_plaid_index, corpus_ids
+        else:
+            # Check if LEANN index exists
+            if not rebuild:
+                retriever = _load_retriever_if_index_exists(index_path)
+                if retriever is not None:
+                    print(f"LEANN index already exists at {index_path}")
+                    return retriever, corpus_ids
+
+            print(f"Building LEANN index at {index_path}...")
+            print("Embedding images...")
+            doc_vecs = _embed_images(self._model, self._processor, images)
+
+            retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
+            print(f"LEANN index built")
+            return retriever, corpus_ids
+
+    def search_queries(
+        self,
+        queries: dict[str, str],
+        corpus_ids: list[str],
+        index_or_retriever: Any,
+        fast_plaid_index_path: Optional[str] = None,
+        task_prompt: Optional[dict[str, str]] = None,
+    ) -> dict[str, dict[str, float]]:
+        """
+        Search queries against the index.
+
+        Args:
+            queries: dict mapping query_id to query text
+            corpus_ids: list of corpus_ids in the same order as the index
+            index_or_retriever: index or retriever object
+            fast_plaid_index_path: path to Fast-Plaid index (for metadata)
+            task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
+
+        Returns:
+            results: dict mapping query_id to dict of {corpus_id: score}
+        """
+        self._load_model_if_needed()
+
+        print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
+
+        query_ids = list(queries.keys())
+        query_texts = [queries[qid] for qid in query_ids]
+
+        # Note: ColPaliEngineWrapper does NOT use task prompt from metadata
+        # It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
+        # So we don't append task_prompt here to match MTEB behavior
+
+        # Embed queries
+        print("Embedding queries...")
+        query_vecs = _embed_queries(self._model, self._processor, query_texts)
+
+        results = {}
+
+        for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
+            if self.use_fast_plaid:
+                # Fast-Plaid search
+                search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k)
+                query_results = {}
+                for score, doc_id in search_results:
+                    if doc_id < len(corpus_ids):
+                        corpus_id = corpus_ids[doc_id]
+                        query_results[corpus_id] = float(score)
+            else:
+                # LEANN search
+                import torch
+                query_np = query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
+                search_results = index_or_retriever.search_exact_all(query_np, topk=self.top_k)
+                query_results = {}
+                for score, doc_id in search_results:
+                    if doc_id < len(corpus_ids):
+                        corpus_id = corpus_ids[doc_id]
+                        query_results[corpus_id] = float(score)
+
+            results[query_id] = query_results
+
+        return results
+
+    @staticmethod
+    def evaluate_results(
+        results: dict[str, dict[str, float]],
+        qrels: dict[str, dict[str, int]],
+        k_values: list[int] = None,
+    ) -> dict[str, float]:
+        """
+        Evaluate retrieval results using NDCG and other metrics.
+
+        Args:
+            results: dict mapping query_id to dict of {corpus_id: score}
+            qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
+            k_values: List of k values for evaluation metrics
+
+        Returns:
+            Dictionary of metric scores
+        """
+        try:
+            import pytrec_eval
+            from mteb._evaluators.retrieval_metrics import (
+                calculate_retrieval_scores,
+                make_score_dict,
+            )
+        except ImportError:
+            raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")
+
+        if k_values is None:
+            k_values = [1, 3, 5, 10, 100]
+
+        # Check if we have any queries to evaluate
+        if len(results) == 0:
+            print("Warning: No queries to evaluate. Returning zero scores.")
+            scores = {}
+            for k in k_values:
+                scores[f"ndcg_at_{k}"] = 0.0
+                scores[f"map_at_{k}"] = 0.0
+                scores[f"recall_at_{k}"] = 0.0
+                scores[f"precision_at_{k}"] = 0.0
+                scores[f"mrr_at_{k}"] = 0.0
+            return scores
+
+        print(f"Evaluating results with k_values={k_values}...")
+        print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
+
+        # Filter to ensure qrels and results have the same query set
+        # This matches MTEB behavior: only evaluate queries that exist in both
+        # pytrec_eval only evaluates queries in qrels, so we need to ensure
+        # results contains all queries in qrels, and filter out queries not in qrels
+        results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
+        qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered}
+
+        print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
+
+        if len(results_filtered) != len(qrels_filtered):
+            print(f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries")
+            missing_in_results = set(qrels.keys()) - set(results.keys())
+            if missing_in_results:
+                print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
+                print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
+
+        # Convert qrels to pytrec_eval format
+        qrels_pytrec = {}
+        for qid, rel_docs in qrels_filtered.items():
+            qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()}
+
+        # Evaluate
+        eval_result = calculate_retrieval_scores(
+            results=results_filtered,
+            qrels=qrels_pytrec,
+            k_values=k_values,
+        )
+
+        # Format scores
+        scores = make_score_dict(
+            ndcg=eval_result.ndcg,
+            _map=eval_result.map,
+            recall=eval_result.recall,
+            precision=eval_result.precision,
+            mrr=eval_result.mrr,
+            naucs=eval_result.naucs,
+            naucs_mrr=eval_result.naucs_mrr,
+            cv_recall=eval_result.cv_recall,
+            task_scores={},
+        )
+
+        return scores
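Note: the following is a minimal usage sketch of the class added above, not part of the commit. It only exercises the method signatures shown in this hunk; the toy image, query text, and ids are invented for illustration, and a real run would additionally download the configured retrieval model.

# Illustrative sketch (assumptions: leann_multi_vector is importable and the
# model weights are available locally or downloadable).
from PIL import Image
from leann_multi_vector import ViDoReBenchmarkEvaluator

corpus = {"corpus-test-0": Image.new("RGB", (64, 64))}   # corpus_id -> PIL image
queries = {"query-test-0": "What was the total revenue?"}  # query_id -> text
qrels = {"query-test-0": {"corpus-test-0": 1}}            # query_id -> {corpus_id: rel}

evaluator = ViDoReBenchmarkEvaluator(model_name="colqwen2", use_fast_plaid=False, top_k=100)
index, corpus_ids = evaluator.build_index_from_corpus(corpus, index_path="./indexes/demo")
results = evaluator.search_queries(queries, corpus_ids, index)
scores = evaluator.evaluate_results(results, qrels, k_values=[1, 5, 10])
print(scores.get("ndcg_at_5"))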
@@ -0,0 +1,389 @@
+#!/usr/bin/env python3
+"""
+Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
+
+This script uses the interface from leann_multi_vector.py to:
+1. Download ViDoRe v1 datasets
+2. Build indexes (LEANN or Fast-Plaid)
+3. Perform retrieval
+4. Evaluate using NDCG metrics
+
+Usage:
+    # Evaluate all ViDoRe v1 tasks
+    python vidore_v1_benchmark.py --model colqwen2 --tasks all
+
+    # Evaluate specific task
+    python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
+
+    # Use Fast-Plaid index
+    python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
+
+    # Rebuild index
+    python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
+"""
+
+import argparse
+import json
+import os
+from typing import Optional
+
+from datasets import load_dataset
+from PIL import Image
+
+from leann_multi_vector import (
+    _ensure_repo_paths_importable,
+    ViDoReBenchmarkEvaluator,
+)
+
+_ensure_repo_paths_importable(__file__)
+
+# ViDoRe v1 task configurations
+# Prompts match MTEB task metadata prompts
+VIDORE_V1_TASKS = {
+    "VidoreArxivQARetrieval": {
+        "dataset_path": "vidore/arxivqa_test_subsampled_beir",
+        "revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreDocVQARetrieval": {
+        "dataset_path": "vidore/docvqa_test_subsampled_beir",
+        "revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreInfoVQARetrieval": {
+        "dataset_path": "vidore/infovqa_test_subsampled_beir",
+        "revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreTabfquadRetrieval": {
+        "dataset_path": "vidore/tabfquad_test_subsampled_beir",
+        "revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreTatdqaRetrieval": {
+        "dataset_path": "vidore/tatdqa_test_beir",
+        "revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreShiftProjectRetrieval": {
+        "dataset_path": "vidore/shiftproject_test_beir",
+        "revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreSyntheticDocQAAIRetrieval": {
+        "dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
+        "revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreSyntheticDocQAEnergyRetrieval": {
+        "dataset_path": "vidore/syntheticDocQA_energy_test_beir",
+        "revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreSyntheticDocQAGovernmentReportsRetrieval": {
+        "dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
+        "revision": "b4909afa930f81282fd20601e860668073ad02aa",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+    "VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
+        "dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
+        "revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
+        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
+    },
+}
+
+
+def load_vidore_v1_data(
+    dataset_path: str,
+    revision: Optional[str] = None,
+    split: str = "test",
+):
+    """
+    Load ViDoRe v1 dataset.
+
+    Returns:
+        corpus: dict mapping corpus_id to PIL Image
+        queries: dict mapping query_id to query text
+        qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
+    """
+    print(f"Loading dataset: {dataset_path} (split={split})")
+
+    # Load queries
+    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+
+    queries = {}
+    for row in query_ds:
+        query_id = f"query-{split}-{row['query-id']}"
+        queries[query_id] = row["query"]
+
+    # Load corpus (images)
+    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
+
+    corpus = {}
+    for row in corpus_ds:
+        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        # Extract image from the dataset row
+        if "image" in row:
+            corpus[corpus_id] = row["image"]
+        elif "page_image" in row:
+            corpus[corpus_id] = row["page_image"]
+        else:
+            raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")
+
+    # Load qrels (relevance judgments)
+    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
+
+    qrels = {}
+    for row in qrels_ds:
+        query_id = f"query-{split}-{row['query-id']}"
+        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        if query_id not in qrels:
+            qrels[query_id] = {}
+        qrels[query_id][corpus_id] = int(row["score"])
+
+    print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")
+
+    # Filter qrels to only include queries that exist
+    qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
+
+    # Filter out queries without any relevant documents (matching MTEB behavior)
+    # This is important for correct NDCG calculation
+    qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
+    queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
+
+    print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
+
+    return corpus, queries_filtered, qrels_filtered
+
+
+def evaluate_task(
+    task_name: str,
+    model_name: str,
+    index_path: str,
+    use_fast_plaid: bool = False,
+    fast_plaid_index_path: Optional[str] = None,
+    rebuild_index: bool = False,
+    top_k: int = 1000,
+    first_stage_k: int = 500,
+    k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
+    output_dir: Optional[str] = None,
+):
+    """
+    Evaluate a single ViDoRe v1 task.
+    """
+    print(f"\n{'='*80}")
+    print(f"Evaluating task: {task_name}")
+    print(f"{'='*80}")
+
+    # Get task config
+    if task_name not in VIDORE_V1_TASKS:
+        raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
+
+    task_config = VIDORE_V1_TASKS[task_name]
+    dataset_path = task_config["dataset_path"]
+    revision = task_config["revision"]
+
+    # Load data
+    corpus, queries, qrels = load_vidore_v1_data(
+        dataset_path=dataset_path,
+        revision=revision,
+        split="test",
+    )
+
+    # Check if we have any queries
+    if len(queries) == 0:
+        print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
+        # Return zero scores
+        scores = {}
+        for k in k_values:
+            scores[f"ndcg_at_{k}"] = 0.0
+            scores[f"map_at_{k}"] = 0.0
+            scores[f"recall_at_{k}"] = 0.0
+            scores[f"precision_at_{k}"] = 0.0
+            scores[f"mrr_at_{k}"] = 0.0
+        return scores
+
+    # Initialize evaluator
+    evaluator = ViDoReBenchmarkEvaluator(
+        model_name=model_name,
+        use_fast_plaid=use_fast_plaid,
+        top_k=top_k,
+        first_stage_k=first_stage_k,
+        k_values=k_values,
+    )
+
+    # Build or load index
+    index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
+    if index_path_full is None:
+        index_path_full = f"./indexes/{task_name}_{model_name}"
+        if use_fast_plaid:
+            index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
+
+    index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
+        corpus=corpus,
+        index_path=index_path_full,
+        rebuild=rebuild_index,
+    )
+
+    # Search queries
+    task_prompt = task_config.get("prompt")
+    results = evaluator.search_queries(
+        queries=queries,
+        corpus_ids=corpus_ids_ordered,
+        index_or_retriever=index_or_retriever,
+        fast_plaid_index_path=fast_plaid_index_path,
+        task_prompt=task_prompt,
+    )
+
+    # Evaluate
+    scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
+
+    # Print results
+    print(f"\n{'='*80}")
+    print(f"Results for {task_name}:")
+    print(f"{'='*80}")
+    for metric, value in scores.items():
+        if isinstance(value, (int, float)):
+            print(f"  {metric}: {value:.5f}")
+
+    # Save results
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+        results_file = os.path.join(output_dir, f"{task_name}_results.json")
+        scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
+
+        with open(results_file, "w") as f:
+            json.dump(results, f, indent=2)
+        print(f"\nSaved results to: {results_file}")
+
+        with open(scores_file, "w") as f:
+            json.dump(scores, f, indent=2)
+        print(f"Saved scores to: {scores_file}")
+
+    return scores
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="colqwen2",
+        choices=["colqwen2", "colpali"],
+        help="Model to use",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        default=None,
+        help="Specific task to evaluate (or 'all' for all tasks)",
+    )
+    parser.add_argument(
+        "--tasks",
+        type=str,
+        default="all",
+        help="Tasks to evaluate: 'all' or comma-separated list",
+    )
+    parser.add_argument(
+        "--index-path",
+        type=str,
+        default=None,
+        help="Path to LEANN index (auto-generated if not provided)",
+    )
+    parser.add_argument(
+        "--use-fast-plaid",
+        action="store_true",
+        help="Use Fast-Plaid instead of LEANN",
+    )
+    parser.add_argument(
+        "--fast-plaid-index-path",
+        type=str,
+        default=None,
+        help="Path to Fast-Plaid index (auto-generated if not provided)",
+    )
+    parser.add_argument(
+        "--rebuild-index",
+        action="store_true",
+        help="Rebuild index even if it exists",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=1000,
+        help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
+    )
+    parser.add_argument(
+        "--first-stage-k",
+        type=int,
+        default=500,
+        help="First stage k for LEANN search",
+    )
+    parser.add_argument(
+        "--k-values",
+        type=str,
+        default="1,3,5,10,20,100,1000",
+        help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./vidore_v1_results",
+        help="Output directory for results",
+    )
+
+    args = parser.parse_args()
+
+    # Parse k_values
+    k_values = [int(k.strip()) for k in args.k_values.split(",")]
+
+    # Determine tasks to evaluate
+    if args.task:
+        tasks_to_eval = [args.task]
+    elif args.tasks.lower() == "all":
+        tasks_to_eval = list(VIDORE_V1_TASKS.keys())
+    else:
+        tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
+
+    print(f"Tasks to evaluate: {tasks_to_eval}")
+
+    # Evaluate each task
+    all_scores = {}
+    for task_name in tasks_to_eval:
+        try:
+            scores = evaluate_task(
+                task_name=task_name,
+                model_name=args.model,
+                index_path=args.index_path,
+                use_fast_plaid=args.use_fast_plaid,
+                fast_plaid_index_path=args.fast_plaid_index_path,
+                rebuild_index=args.rebuild_index,
+                top_k=args.top_k,
+                first_stage_k=args.first_stage_k,
+                k_values=k_values,
+                output_dir=args.output_dir,
+            )
+            all_scores[task_name] = scores
+        except Exception as e:
+            print(f"\nError evaluating {task_name}: {e}")
+            import traceback
+            traceback.print_exc()
+            continue
+
+    # Print summary
+    if all_scores:
+        print(f"\n{'='*80}")
+        print("SUMMARY")
+        print(f"{'='*80}")
+        for task_name, scores in all_scores.items():
+            print(f"\n{task_name}:")
+            # Print main metrics
+            for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
+                if metric in scores:
+                    print(f"  {metric}: {scores[metric]:.5f}")
+
+
+if __name__ == "__main__":
+    main()
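Note: the sketch below is illustrative and not part of the commit. It spells out the BEIR-style id convention the new script relies on: queries, corpus entries, and qrels are keyed with matching "query-{split}-{id}" and "corpus-{split}-{id}" strings, which is what lets search results and relevance judgments line up during evaluation. The concrete ids, query text, and score are invented.

# Illustrative sketch of the aligned structures produced by load_vidore_v1_data
# and consumed by ViDoReBenchmarkEvaluator (values invented for illustration).
from PIL import Image

corpus = {"corpus-test-12": Image.new("RGB", (32, 32))}    # corpus_id -> PIL image
queries = {"query-test-3": "What was the 2019 revenue?"}   # query_id -> query text
qrels = {"query-test-3": {"corpus-test-12": 1}}            # query_id -> {corpus_id: relevance}
results = {"query-test-3": {"corpus-test-12": 17.4}}       # query_id -> {corpus_id: score}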
@@ -25,38 +25,14 @@ Usage:
 import argparse
 import json
 import os
-import time
-from pathlib import Path
-from typing import Any, Optional
+from typing import Optional
 
-import numpy as np
 from datasets import load_dataset
 from PIL import Image
-from tqdm import tqdm
-
-# Import MTEB for evaluation metrics
-try:
-    import pytrec_eval
-    from mteb._evaluators.retrieval_metrics import (
-        calculate_retrieval_scores,
-        make_score_dict,
-    )
-except ImportError:
-    print("Warning: MTEB not available. Install with: pip install mteb")
-    pytrec_eval = None
 
 from leann_multi_vector import (
     _ensure_repo_paths_importable,
-    _load_colvision,
-    _embed_images,
-    _embed_queries,
-    _build_index,
-    _load_retriever_if_index_exists,
-    _build_fast_plaid_index,
-    _load_fast_plaid_index_if_exists,
-    _search_fast_plaid,
-    _get_fast_plaid_image,
-    _get_fast_plaid_metadata,
+    ViDoReBenchmarkEvaluator,
 )
 
 _ensure_repo_paths_importable(__file__)
@@ -181,194 +157,14 @@ def load_vidore_v2_data(
     # Filter qrels to only include queries that exist
     qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
 
-    return corpus, queries, qrels
-
-
-def build_index_from_corpus(
-    corpus: dict[str, Image.Image],
-    model,
-    processor,
-    index_path: str,
-    use_fast_plaid: bool = False,
-    rebuild: bool = False,
-):
-    """
-    Build index from corpus images.
-
-    Returns:
-        tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
-    """
-    # Ensure consistent ordering
-    corpus_ids = sorted(corpus.keys())  # Sort for consistency
-    images = [corpus[cid] for cid in corpus_ids]
-
-    if use_fast_plaid:
-        # Check if Fast-Plaid index exists
-        if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
-            print(f"Fast-Plaid index already exists at {index_path}")
-            return _load_fast_plaid_index_if_exists(index_path), corpus_ids
-
-        print(f"Building Fast-Plaid index at {index_path}...")
-
-        # Embed images
-        print("Embedding images...")
-        doc_vecs = _embed_images(model, processor, images)
-
-        # Build index
-        fast_plaid_index, build_time = _build_fast_plaid_index(
-            index_path, doc_vecs, corpus_ids, images
-        )
-        print(f"Fast-Plaid index built in {build_time:.2f}s")
-        return fast_plaid_index, corpus_ids
-    else:
-        # Check if LEANN index exists
-        if not rebuild:
-            retriever = _load_retriever_if_index_exists(index_path)
-            if retriever is not None:
-                print(f"LEANN index already exists at {index_path}")
-                return retriever, corpus_ids
-
-        print(f"Building LEANN index at {index_path}...")
-
-        # Embed images
-        print("Embedding images...")
-        doc_vecs = _embed_images(model, processor, images)
-
-        # Build index
-        retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
-        print(f"LEANN index built")
-        return retriever, corpus_ids
-
-
-def search_queries(
-    queries: dict[str, str],
-    corpus_ids: list[str],
-    model,
-    processor,
-    index_or_retriever: Any,
-    use_fast_plaid: bool = False,
-    fast_plaid_index_path: Optional[str] = None,
-    top_k: int = 100,
-    first_stage_k: int = 500,
-    task_prompt: Optional[dict[str, str]] = None,
-) -> dict[str, dict[str, float]]:
-    """
-    Search queries against the index.
-
-    Args:
-        queries: dict mapping query_id to query text
-        corpus_ids: list of corpus_ids in the same order as the index
-        model: model object
-        processor: processor object
-        index_or_retriever: index or retriever object
-        use_fast_plaid: whether using Fast-Plaid
-        fast_plaid_index_path: path to Fast-Plaid index (for metadata)
-        top_k: top-k results to retrieve
-        first_stage_k: first stage k for LEANN search
-        task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
-
-    Returns:
-        results: dict mapping query_id to dict of {corpus_id: score}
-    """
-    print(f"Searching {len(queries)} queries (top_k={top_k})...")
-
-    query_ids = list(queries.keys())
-    query_texts = [queries[qid] for qid in query_ids]
-
-    # Match MTEB: combine queries with instruction/prompt if provided
-    # MTEB's _combine_queries_with_instruction_text does: query + " " + instruction
-    if task_prompt and "query" in task_prompt:
-        instruction = task_prompt["query"]
-        query_texts = [q + " " + instruction for q in query_texts]
-        print(f"Added task prompt to queries: {instruction}")
-
-    # Embed queries
-    print("Embedding queries...")
-    query_vecs = _embed_queries(model, processor, query_texts)
-
-    results = {}
-
-    for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
-        if use_fast_plaid:
-            # Fast-Plaid search
-            search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, top_k)
-            # Convert doc_id back to corpus_id
-            query_results = {}
-            for score, doc_id in search_results:
-                if doc_id < len(corpus_ids):
-                    corpus_id = corpus_ids[doc_id]
-                    query_results[corpus_id] = float(score)
-        else:
-            # LEANN search
-            query_np = query_vec.float().numpy()
-            search_results = index_or_retriever.search_exact_all(query_np, topk=top_k)
-            # Convert doc_id back to corpus_id
-            query_results = {}
-            for score, doc_id in search_results:
-                if doc_id < len(corpus_ids):
-                    corpus_id = corpus_ids[doc_id]
-                    query_results[corpus_id] = float(score)
-
-        results[query_id] = query_results
-
-    return results
-
-
-def evaluate_results(
-    results: dict[str, dict[str, float]],
-    qrels: dict[str, dict[str, int]],
-    k_values: list[int] = [1, 3, 5, 10, 100],
-) -> dict[str, float]:
-    """
-    Evaluate retrieval results using NDCG and other metrics.
-
-    Returns:
-        Dictionary of metric scores
-    """
-    if pytrec_eval is None:
-        raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")
-
-    # Check if we have any queries to evaluate
-    if len(results) == 0:
-        print("Warning: No queries to evaluate. Returning zero scores.")
-        # Return zero scores for all metrics
-        scores = {}
-        for k in k_values:
-            scores[f"ndcg_at_{k}"] = 0.0
-            scores[f"map_at_{k}"] = 0.0
-            scores[f"recall_at_{k}"] = 0.0
-            scores[f"precision_at_{k}"] = 0.0
-            scores[f"mrr_at_{k}"] = 0.0
-        return scores
-
-    print(f"Evaluating results with k_values={k_values}...")
-
-    # Convert qrels to pytrec_eval format
-    qrels_pytrec = {}
-    for qid, rel_docs in qrels.items():
-        qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()}
-
-    # Evaluate
-    eval_result = calculate_retrieval_scores(
-        results=results,
-        qrels=qrels_pytrec,
-        k_values=k_values,
-    )
-
-    # Format scores
-    scores = make_score_dict(
-        ndcg=eval_result.ndcg,
-        _map=eval_result.map,
-        recall=eval_result.recall,
-        precision=eval_result.precision,
-        mrr=eval_result.mrr,
-        naucs=eval_result.naucs,
-        naucs_mrr=eval_result.naucs_mrr,
-        cv_recall=eval_result.cv_recall,
-        task_scores={},
-    )
-
-    return scores
-
-
+    # Filter out queries without any relevant documents (matching MTEB behavior)
+    # This is important for correct NDCG calculation
+    qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
+    queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
+
+    print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
+
+    return corpus, queries_filtered, qrels_filtered
+
+
 def evaluate_task(
@@ -432,10 +228,14 @@ def evaluate_task(
             scores[f"mrr_at_{k}"] = 0.0
         return scores
 
-    # Load model
-    print(f"\nLoading model: {model_name}")
-    model_name_actual, model, processor, device_str, device, dtype = _load_colvision(model_name)
-    print(f"Model loaded: {model_name_actual}")
+    # Initialize evaluator
+    evaluator = ViDoReBenchmarkEvaluator(
+        model_name=model_name,
+        use_fast_plaid=use_fast_plaid,
+        top_k=top_k,
+        first_stage_k=first_stage_k,
+        k_values=k_values,
+    )
 
     # Build or load index
     index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
@@ -444,32 +244,24 @@ def evaluate_task(
         if use_fast_plaid:
             index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
 
-    index_or_retriever, corpus_ids_ordered = build_index_from_corpus(
+    index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
         corpus=corpus,
-        model=model,
-        processor=processor,
         index_path=index_path_full,
-        use_fast_plaid=use_fast_plaid,
         rebuild=rebuild_index,
     )
 
     # Search queries
     task_prompt = task_config.get("prompt")
-    results = search_queries(
+    results = evaluator.search_queries(
         queries=queries,
         corpus_ids=corpus_ids_ordered,
-        model=model,
-        processor=processor,
         index_or_retriever=index_or_retriever,
-        use_fast_plaid=use_fast_plaid,
         fast_plaid_index_path=fast_plaid_index_path,
-        top_k=top_k,
-        first_stage_k=first_stage_k,
         task_prompt=task_prompt,
     )
 
     # Evaluate
-    scores = evaluate_results(results, qrels, k_values=k_values)
+    scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
 
     # Print results
     print(f"\n{'='*80}")