From 07afe546ea14e4ca6c86eb9b0bdbe5e2dedf9ba1 Mon Sep 17 00:00:00 2001 From: yichuan-w Date: Fri, 14 Nov 2025 10:22:42 +0000 Subject: [PATCH] reproduce docvqa results --- .../leann_multi_vector.py | 268 +++++++++++- .../vidore_v1_benchmark.py | 389 ++++++++++++++++++ .../vidore_v2_benchmark.py | 246 +---------- 3 files changed, 666 insertions(+), 237 deletions(-) create mode 100644 apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py diff --git a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py index fc55852..f04d322 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py @@ -223,17 +223,13 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]: model.eval() # Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings: - # 1. MTEB receives batch["text"] which may already include instruction/prompt + # 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text) # 2. Manually adds: query_prefix + text + query_augmentation_token * 10 # 3. Calls processor.process_queries(batch) where batch is now a list of strings # 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10) # - # However, MTEB's approach results in duplicate addition (20 tokens total). - # Since we're already adding the prompt in search_queries, let's try: - # Option 1: Just call process_queries (let it handle all additions) - avoids duplicate - # Option 2: Manual add + process_texts (to avoid duplicate) - # - # Testing shows Option 1 works better - just call process_queries without manual addition + # This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total + # We need to match this exactly to reproduce MTEB results all_embeds = [] batch_size = 32 # Match MTEB's default batch_size @@ -242,9 +238,15 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]: for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"): batch_queries = queries[i:i + batch_size] - # Just call process_queries - it will add query_prefix + text + 10 tokens - # This avoids duplicate addition that happens in MTEB's approach - inputs = processor.process_queries(batch_queries) + # Match MTEB: manually add query_prefix + text + query_augmentation_token * 10 + # Then process_queries will add them again (resulting in 20 augmentation tokens total) + batch = [ + processor.query_prefix + + t + + processor.query_augmentation_token * 10 + for t in batch_queries + ] + inputs = processor.process_queries(batch) inputs = {k: v.to(model.device) for k, v in inputs.items()} if model.device.type == "cuda": @@ -1044,3 +1046,249 @@ class LeannMultiVector: "image_path": meta.get("image_path", ""), } return None + + +class ViDoReBenchmarkEvaluator: + """ + A reusable class for evaluating ViDoRe benchmarks (v1 and v2). + This class encapsulates common functionality for building indexes, searching, and evaluating. + """ + + def __init__( + self, + model_name: str, + use_fast_plaid: bool = False, + top_k: int = 100, + first_stage_k: int = 500, + k_values: list[int] = None, + ): + """ + Initialize the evaluator. 
+ + Args: + model_name: Model name ("colqwen2" or "colpali") + use_fast_plaid: Whether to use Fast-Plaid instead of LEANN + top_k: Top-k results to retrieve + first_stage_k: First stage k for LEANN search + k_values: List of k values for evaluation metrics + """ + self.model_name = model_name + self.use_fast_plaid = use_fast_plaid + self.top_k = top_k + self.first_stage_k = first_stage_k + self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100] + + # Load model once (can be reused across tasks) + self._model = None + self._processor = None + self._model_name_actual = None + + def _load_model_if_needed(self): + """Lazy load the model.""" + if self._model is None: + print(f"\nLoading model: {self.model_name}") + self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(self.model_name) + print(f"Model loaded: {self._model_name_actual}") + + def build_index_from_corpus( + self, + corpus: dict[str, Image.Image], + index_path: str, + rebuild: bool = False, + ) -> tuple[Any, list[str]]: + """ + Build index from corpus images. + + Args: + corpus: dict mapping corpus_id to PIL Image + index_path: Path to save/load the index + rebuild: Whether to rebuild even if index exists + + Returns: + tuple: (retriever or fast_plaid_index object, list of corpus_ids in order) + """ + self._load_model_if_needed() + + # Ensure consistent ordering + corpus_ids = sorted(corpus.keys()) + images = [corpus[cid] for cid in corpus_ids] + + if self.use_fast_plaid: + # Check if Fast-Plaid index exists + if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None: + print(f"Fast-Plaid index already exists at {index_path}") + return _load_fast_plaid_index_if_exists(index_path), corpus_ids + + print(f"Building Fast-Plaid index at {index_path}...") + print("Embedding images...") + doc_vecs = _embed_images(self._model, self._processor, images) + + fast_plaid_index, build_time = _build_fast_plaid_index( + index_path, doc_vecs, corpus_ids, images + ) + print(f"Fast-Plaid index built in {build_time:.2f}s") + return fast_plaid_index, corpus_ids + else: + # Check if LEANN index exists + if not rebuild: + retriever = _load_retriever_if_index_exists(index_path) + if retriever is not None: + print(f"LEANN index already exists at {index_path}") + return retriever, corpus_ids + + print(f"Building LEANN index at {index_path}...") + print("Embedding images...") + doc_vecs = _embed_images(self._model, self._processor, images) + + retriever = _build_index(index_path, doc_vecs, corpus_ids, images) + print(f"LEANN index built") + return retriever, corpus_ids + + def search_queries( + self, + queries: dict[str, str], + corpus_ids: list[str], + index_or_retriever: Any, + fast_plaid_index_path: Optional[str] = None, + task_prompt: Optional[dict[str, str]] = None, + ) -> dict[str, dict[str, float]]: + """ + Search queries against the index. 
+ + Args: + queries: dict mapping query_id to query text + corpus_ids: list of corpus_ids in the same order as the index + index_or_retriever: index or retriever object + fast_plaid_index_path: path to Fast-Plaid index (for metadata) + task_prompt: Optional dict with prompt for query (e.g., {"query": "..."}) + + Returns: + results: dict mapping query_id to dict of {corpus_id: score} + """ + self._load_model_if_needed() + + print(f"Searching {len(queries)} queries (top_k={self.top_k})...") + + query_ids = list(queries.keys()) + query_texts = [queries[qid] for qid in query_ids] + + # Note: ColPaliEngineWrapper does NOT use task prompt from metadata + # It uses query_prefix + text + query_augmentation_token (handled in _embed_queries) + # So we don't append task_prompt here to match MTEB behavior + + # Embed queries + print("Embedding queries...") + query_vecs = _embed_queries(self._model, self._processor, query_texts) + + results = {} + + for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs): + if self.use_fast_plaid: + # Fast-Plaid search + search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k) + query_results = {} + for score, doc_id in search_results: + if doc_id < len(corpus_ids): + corpus_id = corpus_ids[doc_id] + query_results[corpus_id] = float(score) + else: + # LEANN search + import torch + query_np = query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec + search_results = index_or_retriever.search_exact_all(query_np, topk=self.top_k) + query_results = {} + for score, doc_id in search_results: + if doc_id < len(corpus_ids): + corpus_id = corpus_ids[doc_id] + query_results[corpus_id] = float(score) + + results[query_id] = query_results + + return results + + @staticmethod + def evaluate_results( + results: dict[str, dict[str, float]], + qrels: dict[str, dict[str, int]], + k_values: list[int] = None, + ) -> dict[str, float]: + """ + Evaluate retrieval results using NDCG and other metrics. + + Args: + results: dict mapping query_id to dict of {corpus_id: score} + qrels: dict mapping query_id to dict of {corpus_id: relevance_score} + k_values: List of k values for evaluation metrics + + Returns: + Dictionary of metric scores + """ + try: + import pytrec_eval + from mteb._evaluators.retrieval_metrics import ( + calculate_retrieval_scores, + make_score_dict, + ) + except ImportError: + raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval") + + if k_values is None: + k_values = [1, 3, 5, 10, 100] + + # Check if we have any queries to evaluate + if len(results) == 0: + print("Warning: No queries to evaluate. 
Returning zero scores.") + scores = {} + for k in k_values: + scores[f"ndcg_at_{k}"] = 0.0 + scores[f"map_at_{k}"] = 0.0 + scores[f"recall_at_{k}"] = 0.0 + scores[f"precision_at_{k}"] = 0.0 + scores[f"mrr_at_{k}"] = 0.0 + return scores + + print(f"Evaluating results with k_values={k_values}...") + print(f"Before filtering: {len(results)} results, {len(qrels)} qrels") + + # Filter to ensure qrels and results have the same query set + # This matches MTEB behavior: only evaluate queries that exist in both + # pytrec_eval only evaluates queries in qrels, so we need to ensure + # results contains all queries in qrels, and filter out queries not in qrels + results_filtered = {qid: res for qid, res in results.items() if qid in qrels} + qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered} + + print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels") + + if len(results_filtered) != len(qrels_filtered): + print(f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries") + missing_in_results = set(qrels.keys()) - set(results.keys()) + if missing_in_results: + print(f"Queries in qrels but not in results: {len(missing_in_results)} queries") + print(f"First 5 missing queries: {list(missing_in_results)[:5]}") + + # Convert qrels to pytrec_eval format + qrels_pytrec = {} + for qid, rel_docs in qrels_filtered.items(): + qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()} + + # Evaluate + eval_result = calculate_retrieval_scores( + results=results_filtered, + qrels=qrels_pytrec, + k_values=k_values, + ) + + # Format scores + scores = make_score_dict( + ndcg=eval_result.ndcg, + _map=eval_result.map, + recall=eval_result.recall, + precision=eval_result.precision, + mrr=eval_result.mrr, + naucs=eval_result.naucs, + naucs_mrr=eval_result.naucs_mrr, + cv_recall=eval_result.cv_recall, + task_scores={}, + ) + + return scores diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py new file mode 100644 index 0000000..3b51f62 --- /dev/null +++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +""" +Modular script to reproduce NDCG results for ViDoRe v1 benchmark. + +This script uses the interface from leann_multi_vector.py to: +1. Download ViDoRe v1 datasets +2. Build indexes (LEANN or Fast-Plaid) +3. Perform retrieval +4. 
Evaluate using NDCG metrics + +Usage: + # Evaluate all ViDoRe v1 tasks + python vidore_v1_benchmark.py --model colqwen2 --tasks all + + # Evaluate specific task + python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval + + # Use Fast-Plaid index + python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid + + # Rebuild index + python vidore_v1_benchmark.py --model colqwen2 --rebuild-index +""" + +import argparse +import json +import os +from typing import Optional + +from datasets import load_dataset +from PIL import Image + +from leann_multi_vector import ( + _ensure_repo_paths_importable, + ViDoReBenchmarkEvaluator, +) + +_ensure_repo_paths_importable(__file__) + +# ViDoRe v1 task configurations +# Prompts match MTEB task metadata prompts +VIDORE_V1_TASKS = { + "VidoreArxivQARetrieval": { + "dataset_path": "vidore/arxivqa_test_subsampled_beir", + "revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreDocVQARetrieval": { + "dataset_path": "vidore/docvqa_test_subsampled_beir", + "revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreInfoVQARetrieval": { + "dataset_path": "vidore/infovqa_test_subsampled_beir", + "revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreTabfquadRetrieval": { + "dataset_path": "vidore/tabfquad_test_subsampled_beir", + "revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreTatdqaRetrieval": { + "dataset_path": "vidore/tatdqa_test_beir", + "revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreShiftProjectRetrieval": { + "dataset_path": "vidore/shiftproject_test_beir", + "revision": "84a382e05c4473fed9cff2bbae95fe2379416117", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreSyntheticDocQAAIRetrieval": { + "dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir", + "revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreSyntheticDocQAEnergyRetrieval": { + "dataset_path": "vidore/syntheticDocQA_energy_test_beir", + "revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreSyntheticDocQAGovernmentReportsRetrieval": { + "dataset_path": "vidore/syntheticDocQA_government_reports_test_beir", + "revision": "b4909afa930f81282fd20601e860668073ad02aa", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "VidoreSyntheticDocQAHealthcareIndustryRetrieval": { + "dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir", + "revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164", + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, +} + + +def load_vidore_v1_data( + dataset_path: str, + revision: Optional[str] = None, + split: str = "test", +): + """ + Load ViDoRe v1 dataset. 
+ + Returns: + corpus: dict mapping corpus_id to PIL Image + queries: dict mapping query_id to query text + qrels: dict mapping query_id to dict of {corpus_id: relevance_score} + """ + print(f"Loading dataset: {dataset_path} (split={split})") + + # Load queries + query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) + + queries = {} + for row in query_ds: + query_id = f"query-{split}-{row['query-id']}" + queries[query_id] = row["query"] + + # Load corpus (images) + corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) + + corpus = {} + for row in corpus_ds: + corpus_id = f"corpus-{split}-{row['corpus-id']}" + # Extract image from the dataset row + if "image" in row: + corpus[corpus_id] = row["image"] + elif "page_image" in row: + corpus[corpus_id] = row["page_image"] + else: + raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}") + + # Load qrels (relevance judgments) + qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) + + qrels = {} + for row in qrels_ds: + query_id = f"query-{split}-{row['query-id']}" + corpus_id = f"corpus-{split}-{row['corpus-id']}" + if query_id not in qrels: + qrels[query_id] = {} + qrels[query_id][corpus_id] = int(row["score"]) + + print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings") + + # Filter qrels to only include queries that exist + qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries} + + # Filter out queries without any relevant documents (matching MTEB behavior) + # This is important for correct NDCG calculation + qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0} + queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered} + + print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings") + + return corpus, queries_filtered, qrels_filtered + + +def evaluate_task( + task_name: str, + model_name: str, + index_path: str, + use_fast_plaid: bool = False, + fast_plaid_index_path: Optional[str] = None, + rebuild_index: bool = False, + top_k: int = 1000, + first_stage_k: int = 500, + k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000], + output_dir: Optional[str] = None, +): + """ + Evaluate a single ViDoRe v1 task. + """ + print(f"\n{'='*80}") + print(f"Evaluating task: {task_name}") + print(f"{'='*80}") + + # Get task config + if task_name not in VIDORE_V1_TASKS: + raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}") + + task_config = VIDORE_V1_TASKS[task_name] + dataset_path = task_config["dataset_path"] + revision = task_config["revision"] + + # Load data + corpus, queries, qrels = load_vidore_v1_data( + dataset_path=dataset_path, + revision=revision, + split="test", + ) + + # Check if we have any queries + if len(queries) == 0: + print(f"\nWarning: No queries found for task {task_name}. 
Skipping evaluation.") + # Return zero scores + scores = {} + for k in k_values: + scores[f"ndcg_at_{k}"] = 0.0 + scores[f"map_at_{k}"] = 0.0 + scores[f"recall_at_{k}"] = 0.0 + scores[f"precision_at_{k}"] = 0.0 + scores[f"mrr_at_{k}"] = 0.0 + return scores + + # Initialize evaluator + evaluator = ViDoReBenchmarkEvaluator( + model_name=model_name, + use_fast_plaid=use_fast_plaid, + top_k=top_k, + first_stage_k=first_stage_k, + k_values=k_values, + ) + + # Build or load index + index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path + if index_path_full is None: + index_path_full = f"./indexes/{task_name}_{model_name}" + if use_fast_plaid: + index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid" + + index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus( + corpus=corpus, + index_path=index_path_full, + rebuild=rebuild_index, + ) + + # Search queries + task_prompt = task_config.get("prompt") + results = evaluator.search_queries( + queries=queries, + corpus_ids=corpus_ids_ordered, + index_or_retriever=index_or_retriever, + fast_plaid_index_path=fast_plaid_index_path, + task_prompt=task_prompt, + ) + + # Evaluate + scores = evaluator.evaluate_results(results, qrels, k_values=k_values) + + # Print results + print(f"\n{'='*80}") + print(f"Results for {task_name}:") + print(f"{'='*80}") + for metric, value in scores.items(): + if isinstance(value, (int, float)): + print(f" {metric}: {value:.5f}") + + # Save results + if output_dir: + os.makedirs(output_dir, exist_ok=True) + results_file = os.path.join(output_dir, f"{task_name}_results.json") + scores_file = os.path.join(output_dir, f"{task_name}_scores.json") + + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + print(f"\nSaved results to: {results_file}") + + with open(scores_file, "w") as f: + json.dump(scores, f, indent=2) + print(f"Saved scores to: {scores_file}") + + return scores + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing" + ) + parser.add_argument( + "--model", + type=str, + default="colqwen2", + choices=["colqwen2", "colpali"], + help="Model to use", + ) + parser.add_argument( + "--task", + type=str, + default=None, + help="Specific task to evaluate (or 'all' for all tasks)", + ) + parser.add_argument( + "--tasks", + type=str, + default="all", + help="Tasks to evaluate: 'all' or comma-separated list", + ) + parser.add_argument( + "--index-path", + type=str, + default=None, + help="Path to LEANN index (auto-generated if not provided)", + ) + parser.add_argument( + "--use-fast-plaid", + action="store_true", + help="Use Fast-Plaid instead of LEANN", + ) + parser.add_argument( + "--fast-plaid-index-path", + type=str, + default=None, + help="Path to Fast-Plaid index (auto-generated if not provided)", + ) + parser.add_argument( + "--rebuild-index", + action="store_true", + help="Rebuild index even if it exists", + ) + parser.add_argument( + "--top-k", + type=int, + default=1000, + help="Top-k results to retrieve (MTEB default is max(k_values)=1000)", + ) + parser.add_argument( + "--first-stage-k", + type=int, + default=500, + help="First stage k for LEANN search", + ) + parser.add_argument( + "--k-values", + type=str, + default="1,3,5,10,20,100,1000", + help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./vidore_v1_results", + help="Output directory for results", + ) + + args = parser.parse_args() + + 
# Parse k_values + k_values = [int(k.strip()) for k in args.k_values.split(",")] + + # Determine tasks to evaluate + if args.task: + tasks_to_eval = [args.task] + elif args.tasks.lower() == "all": + tasks_to_eval = list(VIDORE_V1_TASKS.keys()) + else: + tasks_to_eval = [t.strip() for t in args.tasks.split(",")] + + print(f"Tasks to evaluate: {tasks_to_eval}") + + # Evaluate each task + all_scores = {} + for task_name in tasks_to_eval: + try: + scores = evaluate_task( + task_name=task_name, + model_name=args.model, + index_path=args.index_path, + use_fast_plaid=args.use_fast_plaid, + fast_plaid_index_path=args.fast_plaid_index_path, + rebuild_index=args.rebuild_index, + top_k=args.top_k, + first_stage_k=args.first_stage_k, + k_values=k_values, + output_dir=args.output_dir, + ) + all_scores[task_name] = scores + except Exception as e: + print(f"\nError evaluating {task_name}: {e}") + import traceback + traceback.print_exc() + continue + + # Print summary + if all_scores: + print(f"\n{'='*80}") + print("SUMMARY") + print(f"{'='*80}") + for task_name, scores in all_scores.items(): + print(f"\n{task_name}:") + # Print main metrics + for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]: + if metric in scores: + print(f" {metric}: {scores[metric]:.5f}") + + +if __name__ == "__main__": + main() + diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py index 6150ad6..5213e4c 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py @@ -25,38 +25,14 @@ Usage: import argparse import json import os -import time -from pathlib import Path -from typing import Any, Optional +from typing import Optional -import numpy as np from datasets import load_dataset from PIL import Image -from tqdm import tqdm - -# Import MTEB for evaluation metrics -try: - import pytrec_eval - from mteb._evaluators.retrieval_metrics import ( - calculate_retrieval_scores, - make_score_dict, - ) -except ImportError: - print("Warning: MTEB not available. Install with: pip install mteb") - pytrec_eval = None from leann_multi_vector import ( _ensure_repo_paths_importable, - _load_colvision, - _embed_images, - _embed_queries, - _build_index, - _load_retriever_if_index_exists, - _build_fast_plaid_index, - _load_fast_plaid_index_if_exists, - _search_fast_plaid, - _get_fast_plaid_image, - _get_fast_plaid_metadata, + ViDoReBenchmarkEvaluator, ) _ensure_repo_paths_importable(__file__) @@ -181,194 +157,14 @@ def load_vidore_v2_data( # Filter qrels to only include queries that exist qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries} - return corpus, queries, qrels - - -def build_index_from_corpus( - corpus: dict[str, Image.Image], - model, - processor, - index_path: str, - use_fast_plaid: bool = False, - rebuild: bool = False, -): - """ - Build index from corpus images. 
+ # Filter out queries without any relevant documents (matching MTEB behavior) + # This is important for correct NDCG calculation + qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0} + queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered} - Returns: - tuple: (retriever or fast_plaid_index object, list of corpus_ids in order) - """ - # Ensure consistent ordering - corpus_ids = sorted(corpus.keys()) # Sort for consistency - images = [corpus[cid] for cid in corpus_ids] + print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings") - if use_fast_plaid: - # Check if Fast-Plaid index exists - if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None: - print(f"Fast-Plaid index already exists at {index_path}") - return _load_fast_plaid_index_if_exists(index_path), corpus_ids - - print(f"Building Fast-Plaid index at {index_path}...") - - # Embed images - print("Embedding images...") - doc_vecs = _embed_images(model, processor, images) - - # Build index - fast_plaid_index, build_time = _build_fast_plaid_index( - index_path, doc_vecs, corpus_ids, images - ) - print(f"Fast-Plaid index built in {build_time:.2f}s") - return fast_plaid_index, corpus_ids - else: - # Check if LEANN index exists - if not rebuild: - retriever = _load_retriever_if_index_exists(index_path) - if retriever is not None: - print(f"LEANN index already exists at {index_path}") - return retriever, corpus_ids - - print(f"Building LEANN index at {index_path}...") - - # Embed images - print("Embedding images...") - doc_vecs = _embed_images(model, processor, images) - - # Build index - retriever = _build_index(index_path, doc_vecs, corpus_ids, images) - print(f"LEANN index built") - return retriever, corpus_ids - - -def search_queries( - queries: dict[str, str], - corpus_ids: list[str], - model, - processor, - index_or_retriever: Any, - use_fast_plaid: bool = False, - fast_plaid_index_path: Optional[str] = None, - top_k: int = 100, - first_stage_k: int = 500, - task_prompt: Optional[dict[str, str]] = None, -) -> dict[str, dict[str, float]]: - """ - Search queries against the index. 
- - Args: - queries: dict mapping query_id to query text - corpus_ids: list of corpus_ids in the same order as the index - model: model object - processor: processor object - index_or_retriever: index or retriever object - use_fast_plaid: whether using Fast-Plaid - fast_plaid_index_path: path to Fast-Plaid index (for metadata) - top_k: top-k results to retrieve - first_stage_k: first stage k for LEANN search - task_prompt: Optional dict with prompt for query (e.g., {"query": "..."}) - - Returns: - results: dict mapping query_id to dict of {corpus_id: score} - """ - print(f"Searching {len(queries)} queries (top_k={top_k})...") - - query_ids = list(queries.keys()) - query_texts = [queries[qid] for qid in query_ids] - - # Match MTEB: combine queries with instruction/prompt if provided - # MTEB's _combine_queries_with_instruction_text does: query + " " + instruction - if task_prompt and "query" in task_prompt: - instruction = task_prompt["query"] - query_texts = [q + " " + instruction for q in query_texts] - print(f"Added task prompt to queries: {instruction}") - - # Embed queries - print("Embedding queries...") - query_vecs = _embed_queries(model, processor, query_texts) - - results = {} - - for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs): - if use_fast_plaid: - # Fast-Plaid search - search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, top_k) - # Convert doc_id back to corpus_id - query_results = {} - for score, doc_id in search_results: - if doc_id < len(corpus_ids): - corpus_id = corpus_ids[doc_id] - query_results[corpus_id] = float(score) - else: - # LEANN search - query_np = query_vec.float().numpy() - search_results = index_or_retriever.search_exact_all(query_np, topk=top_k) - # Convert doc_id back to corpus_id - query_results = {} - for score, doc_id in search_results: - if doc_id < len(corpus_ids): - corpus_id = corpus_ids[doc_id] - query_results[corpus_id] = float(score) - - results[query_id] = query_results - - return results - - -def evaluate_results( - results: dict[str, dict[str, float]], - qrels: dict[str, dict[str, int]], - k_values: list[int] = [1, 3, 5, 10, 100], -) -> dict[str, float]: - """ - Evaluate retrieval results using NDCG and other metrics. - - Returns: - Dictionary of metric scores - """ - if pytrec_eval is None: - raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval") - - # Check if we have any queries to evaluate - if len(results) == 0: - print("Warning: No queries to evaluate. 
Returning zero scores.") - # Return zero scores for all metrics - scores = {} - for k in k_values: - scores[f"ndcg_at_{k}"] = 0.0 - scores[f"map_at_{k}"] = 0.0 - scores[f"recall_at_{k}"] = 0.0 - scores[f"precision_at_{k}"] = 0.0 - scores[f"mrr_at_{k}"] = 0.0 - return scores - - print(f"Evaluating results with k_values={k_values}...") - - # Convert qrels to pytrec_eval format - qrels_pytrec = {} - for qid, rel_docs in qrels.items(): - qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()} - - # Evaluate - eval_result = calculate_retrieval_scores( - results=results, - qrels=qrels_pytrec, - k_values=k_values, - ) - - # Format scores - scores = make_score_dict( - ndcg=eval_result.ndcg, - _map=eval_result.map, - recall=eval_result.recall, - precision=eval_result.precision, - mrr=eval_result.mrr, - naucs=eval_result.naucs, - naucs_mrr=eval_result.naucs_mrr, - cv_recall=eval_result.cv_recall, - task_scores={}, - ) - - return scores + return corpus, queries_filtered, qrels_filtered def evaluate_task( @@ -432,10 +228,14 @@ def evaluate_task( scores[f"mrr_at_{k}"] = 0.0 return scores - # Load model - print(f"\nLoading model: {model_name}") - model_name_actual, model, processor, device_str, device, dtype = _load_colvision(model_name) - print(f"Model loaded: {model_name_actual}") + # Initialize evaluator + evaluator = ViDoReBenchmarkEvaluator( + model_name=model_name, + use_fast_plaid=use_fast_plaid, + top_k=top_k, + first_stage_k=first_stage_k, + k_values=k_values, + ) # Build or load index index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path @@ -444,32 +244,24 @@ def evaluate_task( if use_fast_plaid: index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid" - index_or_retriever, corpus_ids_ordered = build_index_from_corpus( + index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus( corpus=corpus, - model=model, - processor=processor, index_path=index_path_full, - use_fast_plaid=use_fast_plaid, rebuild=rebuild_index, ) # Search queries task_prompt = task_config.get("prompt") - results = search_queries( + results = evaluator.search_queries( queries=queries, corpus_ids=corpus_ids_ordered, - model=model, - processor=processor, index_or_retriever=index_or_retriever, - use_fast_plaid=use_fast_plaid, fast_plaid_index_path=fast_plaid_index_path, - top_k=top_k, - first_stage_k=first_stage_k, task_prompt=task_prompt, ) # Evaluate - scores = evaluate_results(results, qrels, k_values=k_values) + scores = evaluator.evaluate_results(results, qrels, k_values=k_values) # Print results print(f"\n{'='*80}")
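
For reference, a minimal, dependency-free sketch of the query construction that the updated comment in _embed_queries describes: the manual addition plus the processor's own process_queries() means the query prefix appears twice and the augmentation token twenty times. FakeProcessor and its attribute values ("Query: ", "<pad>") are illustrative stand-ins for the real ColQwen2/ColPali processor, not its actual implementation.

    # Sketch only: stand-in values; the real processor exposes query_prefix,
    # query_augmentation_token, and process_queries (which re-adds prefix + 10 tokens).
    class FakeProcessor:
        query_prefix = "Query: "
        query_augmentation_token = "<pad>"

        def process_queries(self, texts):
            # The real method tokenizes; here we only show the string it would see.
            suffix = self.query_augmentation_token * 10
            return [self.query_prefix + t + suffix for t in texts]

    processor = FakeProcessor()
    queries = ["What is the total revenue reported on this page?"]

    # Manual addition (as in _embed_queries) ...
    batch = [processor.query_prefix + q + processor.query_augmentation_token * 10 for q in queries]
    # ... then process_queries adds the same prefix and suffix again, matching MTEB's duplication.
    final = processor.process_queries(batch)
    print(final[0])
    # -> "Query: Query: What is the total revenue reported on this page?<pad>..." (20 <pad> tokens)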
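The search paths in both benchmarks rely on one ordering invariant: documents are embedded and added to the index in sorted(corpus) order, so an integer doc_id returned by LEANN or Fast-Plaid maps back to its corpus_id via corpus_ids[doc_id]. A tiny self-contained sketch of that conversion, with made-up scores and ids:

    # (score, doc_id) pairs as returned by the index; ids index into the sorted corpus_ids list.
    corpus = {"corpus-test-2": "img2", "corpus-test-0": "img0", "corpus-test-1": "img1"}
    corpus_ids = sorted(corpus.keys())       # ["corpus-test-0", "corpus-test-1", "corpus-test-2"]
    search_results = [(21.5, 2), (19.0, 0)]  # illustrative values
    query_results = {
        corpus_ids[doc_id]: float(score)
        for score, doc_id in search_results
        if doc_id < len(corpus_ids)
    }
    print(query_results)  # {"corpus-test-2": 21.5, "corpus-test-0": 19.0}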
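Finally, a hypothetical driver showing how the pieces introduced in this patch compose programmatically, assuming the ColVision dependencies, model weights, and Hugging Face datasets are available and both modules are importable; the index path is illustrative.

    from leann_multi_vector import ViDoReBenchmarkEvaluator
    from vidore_v1_benchmark import VIDORE_V1_TASKS, load_vidore_v1_data

    task = VIDORE_V1_TASKS["VidoreDocVQARetrieval"]
    corpus, queries, qrels = load_vidore_v1_data(
        dataset_path=task["dataset_path"], revision=task["revision"], split="test"
    )

    k_values = [1, 3, 5, 10, 20, 100, 1000]
    evaluator = ViDoReBenchmarkEvaluator(model_name="colqwen2", top_k=max(k_values), k_values=k_values)

    # Build (or load) the LEANN index, retrieve, then score with MTEB's retrieval metrics.
    index, corpus_ids = evaluator.build_index_from_corpus(corpus, index_path="./indexes/docvqa_colqwen2")
    results = evaluator.search_queries(queries, corpus_ids, index, task_prompt=task.get("prompt"))
    scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
    print({m: scores[m] for m in ("ndcg_at_5", "ndcg_at_10", "ndcg_at_100")})

The equivalent CLI run for the DocVQA subset referenced in the commit subject would be: python vidore_v1_benchmark.py --model colqwen2 --task VidoreDocVQARetrieval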