reproduce docvqa results
@@ -25,38 +25,14 @@ Usage:
import argparse
import json
import os
import time
from pathlib import Path
from typing import Any, Optional
from typing import Optional

import numpy as np
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

# Import MTEB for evaluation metrics
try:
    import pytrec_eval
    from mteb._evaluators.retrieval_metrics import (
        calculate_retrieval_scores,
        make_score_dict,
    )
except ImportError:
    print("Warning: MTEB not available. Install with: pip install mteb")
    pytrec_eval = None

from leann_multi_vector import (
    _ensure_repo_paths_importable,
    _load_colvision,
    _embed_images,
    _embed_queries,
    _build_index,
    _load_retriever_if_index_exists,
    _build_fast_plaid_index,
    _load_fast_plaid_index_if_exists,
    _search_fast_plaid,
    _get_fast_plaid_image,
    _get_fast_plaid_metadata,
    ViDoReBenchmarkEvaluator,
)

_ensure_repo_paths_importable(__file__)
@@ -181,194 +157,14 @@ def load_vidore_v2_data(
    # Filter qrels to only include queries that exist
    qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
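    # For orientation, the three return values use the usual corpus/queries/qrels
    # layout (the IDs and text below are hypothetical, purely illustrative):
    #   corpus:  {"doc_0": <PIL.Image.Image>, ...}
    #   queries: {"q_0": "What was the 2020 revenue?", ...}
    #   qrels:   {"q_0": {"doc_0": 1}, ...}  # query_id -> {corpus_id: relevance}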

    return corpus, queries, qrels


def build_index_from_corpus(
    corpus: dict[str, Image.Image],
    model,
    processor,
    index_path: str,
    use_fast_plaid: bool = False,
    rebuild: bool = False,
):
    """
    Build index from corpus images.
    # Filter out queries without any relevant documents (matching MTEB behavior)
    # This is important for correct NDCG calculation
    qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
    queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}

    Returns:
        tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
    """
    # Ensure consistent ordering
    corpus_ids = sorted(corpus.keys())  # Sort for consistency
    images = [corpus[cid] for cid in corpus_ids]
    print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")

    if use_fast_plaid:
        # Check if Fast-Plaid index exists
        if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
            print(f"Fast-Plaid index already exists at {index_path}")
            return _load_fast_plaid_index_if_exists(index_path), corpus_ids

        print(f"Building Fast-Plaid index at {index_path}...")

        # Embed images
        print("Embedding images...")
        doc_vecs = _embed_images(model, processor, images)

        # Build index
        fast_plaid_index, build_time = _build_fast_plaid_index(
            index_path, doc_vecs, corpus_ids, images
        )
        print(f"Fast-Plaid index built in {build_time:.2f}s")
        return fast_plaid_index, corpus_ids
    else:
        # Check if LEANN index exists
        if not rebuild:
            retriever = _load_retriever_if_index_exists(index_path)
            if retriever is not None:
                print(f"LEANN index already exists at {index_path}")
                return retriever, corpus_ids

        print(f"Building LEANN index at {index_path}...")

        # Embed images
        print("Embedding images...")
        doc_vecs = _embed_images(model, processor, images)

        # Build index
        retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
        print(f"LEANN index built")
        return retriever, corpus_ids


def search_queries(
    queries: dict[str, str],
    corpus_ids: list[str],
    model,
    processor,
    index_or_retriever: Any,
    use_fast_plaid: bool = False,
    fast_plaid_index_path: Optional[str] = None,
    top_k: int = 100,
    first_stage_k: int = 500,
    task_prompt: Optional[dict[str, str]] = None,
) -> dict[str, dict[str, float]]:
    """
    Search queries against the index.

    Args:
        queries: dict mapping query_id to query text
        corpus_ids: list of corpus_ids in the same order as the index
        model: model object
        processor: processor object
        index_or_retriever: index or retriever object
        use_fast_plaid: whether using Fast-Plaid
        fast_plaid_index_path: path to Fast-Plaid index (for metadata)
        top_k: top-k results to retrieve
        first_stage_k: first stage k for LEANN search
        task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})

    Returns:
        results: dict mapping query_id to dict of {corpus_id: score}
    """
    print(f"Searching {len(queries)} queries (top_k={top_k})...")

    query_ids = list(queries.keys())
    query_texts = [queries[qid] for qid in query_ids]

    # Match MTEB: combine queries with instruction/prompt if provided
    # MTEB's _combine_queries_with_instruction_text does: query + " " + instruction
    if task_prompt and "query" in task_prompt:
        instruction = task_prompt["query"]
        query_texts = [q + " " + instruction for q in query_texts]
        print(f"Added task prompt to queries: {instruction}")
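        # For example (hypothetical strings): a query "what was the total revenue?"
        # with task_prompt {"query": "Find the page that answers the question."}
        # is embedded as "what was the total revenue? Find the page that answers the question."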

    # Embed queries
    print("Embedding queries...")
    query_vecs = _embed_queries(model, processor, query_texts)

    results = {}

    for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
        if use_fast_plaid:
            # Fast-Plaid search
            search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, top_k)
            # Convert doc_id back to corpus_id
            query_results = {}
            for score, doc_id in search_results:
                if doc_id < len(corpus_ids):
                    corpus_id = corpus_ids[doc_id]
                    query_results[corpus_id] = float(score)
        else:
            # LEANN search
            query_np = query_vec.float().numpy()
            search_results = index_or_retriever.search_exact_all(query_np, topk=top_k)
            # Convert doc_id back to corpus_id
            query_results = {}
            for score, doc_id in search_results:
                if doc_id < len(corpus_ids):
                    corpus_id = corpus_ids[doc_id]
                    query_results[corpus_id] = float(score)

        results[query_id] = query_results

    return results
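# A single returned entry might look like this (IDs and scores are hypothetical):
#   {"q_0": {"doc_3": 21.4, "doc_17": 19.8, ...}, "q_1": {...}, ...}
# where higher scores indicate stronger query-document similarity.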


def evaluate_results(
    results: dict[str, dict[str, float]],
    qrels: dict[str, dict[str, int]],
    k_values: list[int] = [1, 3, 5, 10, 100],
) -> dict[str, float]:
    """
    Evaluate retrieval results using NDCG and other metrics.

    Returns:
        Dictionary of metric scores
    """
    if pytrec_eval is None:
        raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")

    # Check if we have any queries to evaluate
    if len(results) == 0:
        print("Warning: No queries to evaluate. Returning zero scores.")
        # Return zero scores for all metrics
        scores = {}
        for k in k_values:
            scores[f"ndcg_at_{k}"] = 0.0
            scores[f"map_at_{k}"] = 0.0
            scores[f"recall_at_{k}"] = 0.0
            scores[f"precision_at_{k}"] = 0.0
            scores[f"mrr_at_{k}"] = 0.0
        return scores

    print(f"Evaluating results with k_values={k_values}...")

    # Convert qrels to pytrec_eval format
    qrels_pytrec = {}
    for qid, rel_docs in qrels.items():
        qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()}

    # Evaluate
    eval_result = calculate_retrieval_scores(
        results=results,
        qrels=qrels_pytrec,
        k_values=k_values,
    )

    # Format scores
    scores = make_score_dict(
        ndcg=eval_result.ndcg,
        _map=eval_result.map,
        recall=eval_result.recall,
        precision=eval_result.precision,
        mrr=eval_result.mrr,
        naucs=eval_result.naucs,
        naucs_mrr=eval_result.naucs_mrr,
        cv_recall=eval_result.cv_recall,
        task_scores={},
    )

    return scores
    return corpus, queries_filtered, qrels_filtered
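

# A minimal, self-contained sketch of how evaluate_results() above is meant to be
# called, using toy query/document IDs and made-up scores (the helper name and all
# values here are hypothetical; running it requires mteb / pytrec_eval):
def _example_evaluate_results() -> dict[str, float]:
    toy_results = {
        "q_0": {"doc_0": 12.3, "doc_1": 7.8},  # retrieval scores, higher = better
        "q_1": {"doc_1": 9.1, "doc_2": 8.4},
    }
    toy_qrels = {
        "q_0": {"doc_0": 1},  # doc_0 is the relevant page for q_0
        "q_1": {"doc_2": 1},
    }
    return evaluate_results(toy_results, toy_qrels, k_values=[1, 3])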


def evaluate_task(
@@ -432,10 +228,14 @@ def evaluate_task(
            scores[f"mrr_at_{k}"] = 0.0
        return scores

    # Load model
    print(f"\nLoading model: {model_name}")
    model_name_actual, model, processor, device_str, device, dtype = _load_colvision(model_name)
    print(f"Model loaded: {model_name_actual}")
    # Initialize evaluator
    evaluator = ViDoReBenchmarkEvaluator(
        model_name=model_name,
        use_fast_plaid=use_fast_plaid,
        top_k=top_k,
        first_stage_k=first_stage_k,
        k_values=k_values,
    )

    # Build or load index
    index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
@@ -444,32 +244,24 @@ def evaluate_task(
    if use_fast_plaid:
        index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"

    index_or_retriever, corpus_ids_ordered = build_index_from_corpus(
    index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
        corpus=corpus,
        model=model,
        processor=processor,
        index_path=index_path_full,
        use_fast_plaid=use_fast_plaid,
        rebuild=rebuild_index,
    )

    # Search queries
    task_prompt = task_config.get("prompt")
    results = search_queries(
    results = evaluator.search_queries(
        queries=queries,
        corpus_ids=corpus_ids_ordered,
        model=model,
        processor=processor,
        index_or_retriever=index_or_retriever,
        use_fast_plaid=use_fast_plaid,
        fast_plaid_index_path=fast_plaid_index_path,
        top_k=top_k,
        first_stage_k=first_stage_k,
        task_prompt=task_prompt,
    )

    # Evaluate
    scores = evaluate_results(results, qrels, k_values=k_values)
    scores = evaluator.evaluate_results(results, qrels, k_values=k_values)

    # Print results
    print(f"\n{'='*80}")