reproduce docvqa results and add debug file
This commit is contained in:
@@ -11,13 +11,13 @@ This script uses the interface from leann_multi_vector.py to:
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
@@ -28,11 +28,9 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
from leann_multi_vector import (
|
||||
_ensure_repo_paths_importable,
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
@@ -85,51 +83,57 @@ def load_vidore_v2_data(
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
|
||||
# Load queries
|
||||
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}').")
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(f"Found {len(query_ds_filtered)} queries using original language value '{language}'")
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
|
||||
queries = {}
|
||||
for row in query_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
queries[query_id] = row["query"]
|
||||
|
||||
|
||||
# Load corpus (images)
|
||||
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||
|
||||
|
||||
corpus = {}
|
||||
for row in corpus_ds:
|
||||
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||
@@ -139,11 +143,13 @@ def load_vidore_v2_data(
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")
|
||||
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
|
||||
|
||||
qrels = {}
|
||||
for row in qrels_ds:
|
||||
query_id = f"query-{split}-{row['query-id']}"
|
||||
@@ -151,19 +157,25 @@ def load_vidore_v2_data(
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")
|
||||
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
|
||||
|
||||
print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
|
||||
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
@@ -177,24 +189,24 @@ def evaluate_task(
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: list[int] = [1, 3, 5, 10, 100],
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'='*80}")
|
||||
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = task_config["dataset_path"]
|
||||
revision = task_config["revision"]
|
||||
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
@@ -206,7 +218,11 @@ def evaluate_task(
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
@@ -214,10 +230,12 @@ def evaluate_task(
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation.")
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
@@ -227,7 +245,7 @@ def evaluate_task(
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
@@ -236,20 +254,20 @@ def evaluate_task(
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
|
||||
# Search queries
|
||||
task_prompt = task_config.get("prompt")
|
||||
results = evaluator.search_queries(
|
||||
@@ -259,32 +277,32 @@ def evaluate_task(
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
|
||||
# Print results
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
@@ -363,12 +381,12 @@ def main():
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
@@ -376,9 +394,9 @@ def main():
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
@@ -400,14 +418,15 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
@@ -418,4 +437,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user