From ae3b8af3df450a8f8ba5ca092f58fc634f075335 Mon Sep 17 00:00:00 2001 From: yichuan-w Date: Fri, 14 Nov 2025 07:31:24 +0000 Subject: [PATCH] update vidore --- .gitignore | 3 +- .../leann_multi_vector.py | 290 +++++++- .../multi-vector-leann-similarity-map.py | 238 +++++-- .../vidore_v2_benchmark.py | 629 ++++++++++++++++++ 4 files changed, 1101 insertions(+), 59 deletions(-) create mode 100644 apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py diff --git a/.gitignore b/.gitignore index 19df865..e60379d 100755 --- a/.gitignore +++ b/.gitignore @@ -91,7 +91,8 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/ *.meta.json *.passages.json - +*.npy +*.db batchtest.py tests/__pytest_cache__/ tests/__pycache__/ diff --git a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py index c557fef..fc55852 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py @@ -219,32 +219,47 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]: def _embed_queries(model, processor, queries: list[str]) -> list[Any]: import torch - from colpali_engine.utils.torch_utils import ListDataset - from torch.utils.data import DataLoader model.eval() - dataloader = DataLoader( - dataset=ListDataset[str](queries), - batch_size=1, - shuffle=False, - collate_fn=lambda x: processor.process_queries(x), - ) - - q_vecs: list[Any] = [] - for batch_query in tqdm(dataloader, desc="Embedding queries"): - with torch.no_grad(): - batch_query = {k: v.to(model.device) for k, v in batch_query.items()} + # Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings: + # 1. MTEB receives batch["text"] which may already include instruction/prompt + # 2. Manually adds: query_prefix + text + query_augmentation_token * 10 + # 3. Calls processor.process_queries(batch) where batch is now a list of strings + # 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10) + # + # However, MTEB's approach results in duplicate addition (20 tokens total). 
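+    # A hypothetical worked example of the duplication (token counts are
+    # illustrative, not from the actual tokenizer): for a 5-token query,
+    # process_queries alone yields [prefix][5 query tokens][10 augmentation
+    # tokens], while manually prepending the prefix and suffix and then calling
+    # process_queries yields the prefix twice plus 20 augmentation tokens.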
+ # Since we're already adding the prompt in search_queries, let's try: + # Option 1: Just call process_queries (let it handle all additions) - avoids duplicate + # Option 2: Manual add + process_texts (to avoid duplicate) + # + # Testing shows Option 1 works better - just call process_queries without manual addition + + all_embeds = [] + batch_size = 32 # Match MTEB's default batch_size + + with torch.no_grad(): + for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"): + batch_queries = queries[i:i + batch_size] + + # Just call process_queries - it will add query_prefix + text + 10 tokens + # This avoids duplicate addition that happens in MTEB's approach + inputs = processor.process_queries(batch_queries) + inputs = {k: v.to(model.device) for k, v in inputs.items()} + if model.device.type == "cuda": with torch.autocast( device_type="cuda", dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16, ): - embeddings_query = model(**batch_query) + outs = model(**inputs) else: - embeddings_query = model(**batch_query) - q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu")))) - return q_vecs + outs = model(**inputs) + + # Match MTEB: convert to float32 on CPU + all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32)))) + + return all_embeds def _build_index( @@ -284,6 +299,247 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]: return None +def _build_fast_plaid_index( + index_path: str, + doc_vecs: list[Any], + filepaths: list[str], + images: list[Image.Image], +) -> tuple[Any, float]: + """ + Build a Fast-Plaid index from document embeddings. + + Args: + index_path: Path to save the Fast-Plaid index + doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim]) + filepaths: List of filepath identifiers for each document + images: List of PIL Images corresponding to each document + + Returns: + Tuple of (FastPlaid index object, build_time_in_seconds) + """ + import torch + from fast_plaid import search as fast_plaid_search + + print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...") + _t0 = time.perf_counter() + + # Convert doc_vecs to list of tensors + documents_embeddings = [] + for i, vec in enumerate(doc_vecs): + if i % 1000 == 0: + print(f" Converting embedding {i}/{len(doc_vecs)}...") + if not isinstance(vec, torch.Tensor): + vec = torch.tensor(vec) if isinstance(vec, np.ndarray) else torch.from_numpy(np.array(vec)) + # Ensure float32 for Fast-Plaid + if vec.dtype != torch.float32: + vec = vec.float() + documents_embeddings.append(vec) + + print(f" Converted {len(documents_embeddings)} embeddings") + if len(documents_embeddings) > 0: + print(f" First embedding shape: {documents_embeddings[0].shape}") + print(f" First embedding dtype: {documents_embeddings[0].dtype}") + + # Prepare metadata for Fast-Plaid + print(f" Preparing metadata for {len(filepaths)} documents...") + metadata_list = [] + for i, filepath in enumerate(filepaths): + metadata_list.append({ + "filepath": filepath, + "index": i, + }) + + # Create Fast-Plaid index + print(f" Creating FastPlaid object with index path: {index_path}") + try: + fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path) + print(f" FastPlaid object created successfully") + except Exception as e: + print(f" Error creating FastPlaid object: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise + + print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...") + try: + 
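+        # documents_embeddings: one [num_tokens, dim] float32 tensor per page;
+        # metadata: one {"filepath", "index"} dict per page. fast-plaid (a
+        # PLAID-style multi-vector engine) persists both under index_path.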
fast_plaid_index.create( + documents_embeddings=documents_embeddings, + metadata=metadata_list, + ) + print(f" Fast-Plaid index created successfully") + except Exception as e: + print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise + + build_secs = time.perf_counter() - _t0 + + # Save images separately (Fast-Plaid doesn't store images) + print(f" Saving {len(images)} images...") + images_dir = Path(index_path) / "images" + images_dir.mkdir(parents=True, exist_ok=True) + for i, img in enumerate(tqdm(images, desc="Saving images")): + img_path = images_dir / f"doc_{i}.png" + img.save(str(img_path)) + + return fast_plaid_index, build_secs + + +def _fast_plaid_index_exists(index_path: str) -> bool: + """ + Check if Fast-Plaid index exists by checking for key files. + This avoids creating the FastPlaid object which may trigger memory allocation. + + Args: + index_path: Path to the Fast-Plaid index + + Returns: + True if index appears to exist, False otherwise + """ + index_path_obj = Path(index_path) + if not index_path_obj.exists() or not index_path_obj.is_dir(): + return False + + # Fast-Plaid creates a SQLite database file for metadata + # Check for metadata.db as the most reliable indicator + metadata_db = index_path_obj / "metadata.db" + if metadata_db.exists() and metadata_db.stat().st_size > 0: + return True + + # Also check if directory has any files (might be incomplete index) + try: + if any(index_path_obj.iterdir()): + return True + except Exception: + pass + + return False + + +def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]: + """ + Load Fast-Plaid index if it exists. + First checks if index files exist, then creates the FastPlaid object. + The actual index data loading happens lazily when search is called. + + Args: + index_path: Path to the Fast-Plaid index + + Returns: + FastPlaid index object if exists, None otherwise + """ + try: + from fast_plaid import search as fast_plaid_search + + # First check if index files exist without creating the object + if not _fast_plaid_index_exists(index_path): + return None + + # Now try to create FastPlaid object + # This may trigger some memory allocation, but the full index loading is deferred + fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path) + return fast_plaid_index + except ImportError: + # fast-plaid not installed + return None + except Exception as e: + # Any error (including memory errors from Rust backend) - return None + # The error will be caught and index will be rebuilt + print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}") + return None + + +def _search_fast_plaid( + fast_plaid_index: Any, + query_vec: Any, + top_k: int, +) -> tuple[list[tuple[float, int]], float]: + """ + Search Fast-Plaid index with a query embedding. 
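+    Scores are late-interaction MaxSim scores (as in ColBERT/PLAID): for each
+    query token vector, take the maximum dot product over the document's token
+    vectors, then sum those maxima over all query tokens.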
+ + Args: + fast_plaid_index: FastPlaid index object + query_vec: Query embedding tensor with shape [num_tokens, embedding_dim] + top_k: Number of top results to return + + Returns: + Tuple of (results_list, search_time_in_seconds) + results_list: List of (score, doc_id) tuples + """ + import torch + + _t0 = time.perf_counter() + + # Ensure query is a torch tensor + if not isinstance(query_vec, torch.Tensor): + q_vec_tensor = torch.tensor(query_vec) if isinstance(query_vec, np.ndarray) else torch.from_numpy(np.array(query_vec)) + else: + q_vec_tensor = query_vec + + # Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim] + if q_vec_tensor.dim() == 2: + q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim] + + # Perform search + scores = fast_plaid_index.search( + queries_embeddings=q_vec_tensor, + top_k=top_k, + show_progress=True, + ) + + search_secs = time.perf_counter() - _t0 + + # Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples + results = [] + if scores and len(scores) > 0: + query_results = scores[0] + # Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format + results = [(float(score), int(doc_id)) for doc_id, score in query_results] + + return results, search_secs + + +def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]: + """ + Retrieve image for a document from Fast-Plaid index. + + Args: + index_path: Path to the Fast-Plaid index + doc_id: Document ID + + Returns: + PIL Image if found, None otherwise + """ + images_dir = Path(index_path) / "images" + image_path = images_dir / f"doc_{doc_id}.png" + + if image_path.exists(): + return Image.open(image_path) + return None + + +def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]: + """ + Retrieve metadata for a document from Fast-Plaid index. 
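+    Looks doc_id up in the metadata store that fast-plaid keeps alongside the
+    index; entries were written by _build_fast_plaid_index as
+    {"filepath": ..., "index": ...} dicts.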
+ + Args: + index_path: Path to the Fast-Plaid index + doc_id: Document ID + + Returns: + Dictionary with metadata if found, None otherwise + """ + try: + from fast_plaid import filtering + metadata_list = filtering.get(index=index_path, subset=[doc_id]) + if metadata_list and len(metadata_list) > 0: + return metadata_list[0] + except Exception: + pass + return None + + def _generate_similarity_map( model, processor, diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py index 4c4c061..c1216b3 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py @@ -2,13 +2,18 @@ # %% # uv pip install matplotlib qwen_vl_utils import argparse +import faulthandler import os import time from typing import Any, Optional +import numpy as np from PIL import Image from tqdm import tqdm +# Enable faulthandler to get stack trace on segfault +faulthandler.enable() + from leann_multi_vector import ( # utility functions/classes _ensure_repo_paths_importable, @@ -20,6 +25,11 @@ from leann_multi_vector import ( # utility functions/classes _build_index, _load_retriever_if_index_exists, _generate_similarity_map, + _build_fast_plaid_index, + _load_fast_plaid_index_if_exists, + _search_fast_plaid, + _get_fast_plaid_image, + _get_fast_plaid_metadata, QwenVL, ) @@ -69,6 +79,8 @@ PAGES_DIR: str = "./pages" # Index + retrieval settings # Use a different index path for larger dataset to avoid overwriting existing index INDEX_PATH: str = "./indexes/colvision_large.leann" +# Fast-Plaid index settings (alternative to LEANN index) +# These are now command-line arguments (see CLI overrides section) TOPK: int = 3 FIRST_STAGE_K: int = 500 REBUILD_INDEX: bool = False @@ -98,24 +110,64 @@ parser.add_argument( default=QUERY, help=f"Query string to search for. Default: '{QUERY}'", ) +parser.add_argument( + "--use-fast-plaid", + action="store_true", + default=False, + help="Set to True to use fast-plaid instead of LEANN. Default: False", +) +parser.add_argument( + "--fast-plaid-index-path", + type=str, + default="./indexes/colvision_fastplaid", + help="Path to the Fast-Plaid index. 
Default: './indexes/colvision_fastplaid'", +) cli_args, _unknown = parser.parse_known_args() SEARCH_METHOD: str = cli_args.search_method QUERY = cli_args.query # Override QUERY with CLI argument if provided +USE_FAST_PLAID: bool = cli_args.use_fast_plaid +FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path # %% # Step 1: Check if we can skip data loading (index already exists) retriever: Optional[Any] = None +fast_plaid_index: Optional[Any] = None need_to_build_index = REBUILD_INDEX -if not REBUILD_INDEX: - retriever = _load_retriever_if_index_exists(INDEX_PATH) - if retriever is not None: - print(f"✓ Index loaded from {INDEX_PATH}") - print(f"✓ Images available at: {retriever._images_dir_path()}") - need_to_build_index = False +if USE_FAST_PLAID: + # Fast-Plaid index handling + if not REBUILD_INDEX: + try: + fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH) + if fast_plaid_index is not None: + print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}") + need_to_build_index = False + else: + print(f"Fast-Plaid index not found, will build new index") + need_to_build_index = True + except Exception as e: + # If loading fails (e.g., memory error, corrupted index), rebuild + print(f"Warning: Failed to load Fast-Plaid index: {e}") + print("Will rebuild the index...") + need_to_build_index = True + fast_plaid_index = None else: - print(f"Index not found, will build new index") + print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index") + need_to_build_index = True +else: + # Original LEANN index handling + if not REBUILD_INDEX: + retriever = _load_retriever_if_index_exists(INDEX_PATH) + if retriever is not None: + print(f"✓ Index loaded from {INDEX_PATH}") + print(f"✓ Images available at: {retriever._images_dir_path()}") + need_to_build_index = False + else: + print(f"Index not found, will build new index") + need_to_build_index = True + else: + print(f"REBUILD_INDEX=True, will rebuild index") need_to_build_index = True # Step 2: Load data only if we need to build the index @@ -347,6 +399,19 @@ if need_to_build_index: f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist." 
) print(f"Loaded {len(images)} images") + + # Memory check before loading model + try: + import psutil + import torch + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB") + if torch.cuda.is_available(): + print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") + print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") + except ImportError: + pass else: print("Skipping dataset loading (using existing index)") filepaths = [] # Not needed when using existing index @@ -355,23 +420,91 @@ else: # %% # Step 3: Load model and processor (only if we need to build index or perform search) -model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) -print(f"Using model={model_name}, device={device_str}, dtype={dtype}") +print("Step 3: Loading model and processor...") +print(f" Model: {MODEL}") +try: + import sys + print(f" Python version: {sys.version}") + print(f" Python executable: {sys.executable}") + + model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) + print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}") + + # Memory check after loading model + try: + import psutil + import torch + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB") + if torch.cuda.is_available(): + print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") + print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") + except ImportError: + pass +except Exception as e: + print(f"✗ Error loading model: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise # %% # %% # Step 4: Build index if needed -if need_to_build_index and retriever is None: - print("Building index...") - doc_vecs = _embed_images(model, processor, images) - retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images) - print(f"✓ Index built and images saved to: {retriever._images_dir_path()}") - # Clear memory - del images, filepaths, doc_vecs +if need_to_build_index: + print("Step 4: Building index...") + print(f" Number of images: {len(images)}") + print(f" Number of filepaths: {len(filepaths)}") + + try: + print(" Embedding images...") + doc_vecs = _embed_images(model, processor, images) + print(f" Embedded {len(doc_vecs)} documents") + print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}") + except Exception as e: + print(f"Error embedding images: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise + + if USE_FAST_PLAID: + # Build Fast-Plaid index + print(" Building Fast-Plaid index...") + try: + fast_plaid_index, build_secs = _build_fast_plaid_index( + FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images + ) + from pathlib import Path + print(f"✓ Fast-Plaid index built in {build_secs:.3f}s") + print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}") + print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}") + except Exception as e: + print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise + finally: + # Clear memory + print(" Clearing memory...") + del images, filepaths, doc_vecs + else: + # Build original LEANN index + try: + retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images) + print(f"✓ Index built and images saved 
to: {retriever._images_dir_path()}")
+        except Exception as e:
+            print(f"Error building LEANN index: {type(e).__name__}: {e}")
+            import traceback
+            traceback.print_exc()
+            raise
+        finally:
+            # Clear memory
+            print("  Clearing memory...")
+            del images, filepaths, doc_vecs
 
-# Note: Images are now stored in the index, retriever will load them on-demand from disk
+# Note: Images are now stored separately; the retriever or fast_plaid_index loads them on demand
 
 
 # %%
@@ -380,44 +513,67 @@ _t0 = time.perf_counter()
 q_vec = _embed_queries(model, processor, [QUERY])[0]
 query_embed_secs = time.perf_counter() - _t0
 
-query_np = q_vec.float().numpy()
-
 print(f"[Search] Method: {SEARCH_METHOD}")
 print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
 
 # Run the selected search method and time it
-_t0 = time.perf_counter()
-if SEARCH_METHOD == "ann":
-    results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
-    search_secs = time.perf_counter() - _t0
-    print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
-elif SEARCH_METHOD == "exact":
-    results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
-    search_secs = time.perf_counter() - _t0
-    print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
-elif SEARCH_METHOD == "exact-all":
-    results = retriever.search_exact_all(query_np, topk=TOPK)
-    search_secs = time.perf_counter() - _t0
-    print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
+if USE_FAST_PLAID:
+    # Fast-Plaid search (search time is measured inside _search_fast_plaid)
+    if fast_plaid_index is None:
+        fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
+        if fast_plaid_index is None:
+            raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
+
+    results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
+    print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
 else:
-    results = []
+    # Original LEANN search
+    query_np = q_vec.float().numpy()
+
+    # Reset the timer here: _t0 was started before query embedding above, so
+    # reusing it would fold embedding time into search_secs.
+    _t0 = time.perf_counter()
+
+    if SEARCH_METHOD == "ann":
+        results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
+        search_secs = time.perf_counter() - _t0
+        print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
+    elif SEARCH_METHOD == "exact":
+        results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
+        search_secs = time.perf_counter() - _t0
+        print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
+    elif SEARCH_METHOD == "exact-all":
+        results = retriever.search_exact_all(query_np, topk=TOPK)
+        search_secs = time.perf_counter() - _t0
+        print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
+    else:
+        results = []
 
 if not results:
     print("No results found.")
 else:
     print(f'Top {len(results)} results for query: "{QUERY}"')
     top_images: list[Image.Image] = []
     for rank, (score, doc_id) in enumerate(results, start=1):
-        # Retrieve image from index instead of memory
-        image = retriever.get_image(doc_id)
-        if image is None:
-            print(f"Warning: Could not retrieve image for doc_id {doc_id}")
-            continue
+        # Retrieve image and metadata based on index type
+        if USE_FAST_PLAID:
+            # Fast-Plaid: load image and get metadata
+            image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
+            if image is None:
+                print(f"Warning: Could not find image for doc_id {doc_id}")
+                continue
+
+            metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
+            path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
+            top_images.append(image)
+        else:
+            # Original
LEANN: retrieve from retriever + image = retriever.get_image(doc_id) + if image is None: + print(f"Warning: Could not retrieve image for doc_id {doc_id}") + continue - metadata = retriever.get_metadata(doc_id) - path = metadata.get("filepath", "unknown") if metadata else "unknown" + metadata = retriever.get_metadata(doc_id) + path = metadata.get("filepath", "unknown") if metadata else "unknown" + top_images.append(image) + # For HF dataset, path is a descriptive identifier, not a real file path print(f"{rank}) MaxSim: {score:.4f}, Page: {path}") - top_images.append(image) if SAVE_TOP_IMAGE: from pathlib import Path as _Path diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py new file mode 100644 index 0000000..6150ad6 --- /dev/null +++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py @@ -0,0 +1,629 @@ +#!/usr/bin/env python3 +""" +Modular script to reproduce NDCG results for ViDoRe v2 benchmark. + +This script uses the interface from leann_multi_vector.py to: +1. Download ViDoRe v2 datasets +2. Build indexes (LEANN or Fast-Plaid) +3. Perform retrieval +4. Evaluate using NDCG metrics + +Usage: + # Evaluate all ViDoRe v2 tasks + python vidore_v2_benchmark.py --model colqwen2 --tasks all + + # Evaluate specific task + python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval + + # Use Fast-Plaid index + python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid + + # Rebuild index + python vidore_v2_benchmark.py --model colqwen2 --rebuild-index +""" + +import argparse +import json +import os +import time +from pathlib import Path +from typing import Any, Optional + +import numpy as np +from datasets import load_dataset +from PIL import Image +from tqdm import tqdm + +# Import MTEB for evaluation metrics +try: + import pytrec_eval + from mteb._evaluators.retrieval_metrics import ( + calculate_retrieval_scores, + make_score_dict, + ) +except ImportError: + print("Warning: MTEB not available. 
Install with: pip install mteb") + pytrec_eval = None + +from leann_multi_vector import ( + _ensure_repo_paths_importable, + _load_colvision, + _embed_images, + _embed_queries, + _build_index, + _load_retriever_if_index_exists, + _build_fast_plaid_index, + _load_fast_plaid_index_if_exists, + _search_fast_plaid, + _get_fast_plaid_image, + _get_fast_plaid_metadata, +) + +_ensure_repo_paths_importable(__file__) + +# Language name to dataset language field value mapping +# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn") +LANGUAGE_MAPPING = { + "english": "eng-Latn", + "french": "fra-Latn", + "spanish": "spa-Latn", + "german": "deu-Latn", +} + +# ViDoRe v2 task configurations +# Prompts match MTEB task metadata prompts +VIDORE_V2_TASKS = { + "Vidore2ESGReportsRetrieval": { + "dataset_path": "vidore/esg_reports_v2", + "revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3", + "languages": ["french", "spanish", "english", "german"], + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "Vidore2EconomicsReportsRetrieval": { + "dataset_path": "vidore/economics_reports_v2", + "revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252", + "languages": ["french", "spanish", "english", "german"], + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "Vidore2BioMedicalLecturesRetrieval": { + "dataset_path": "vidore/biomedical_lectures_v2", + "revision": "a29202f0da409034d651614d87cd8938d254e2ea", + "languages": ["french", "spanish", "english", "german"], + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, + "Vidore2ESGReportsHLRetrieval": { + "dataset_path": "vidore/esg_reports_human_labeled_v2", + "revision": "6d467dedb09a75144ede1421747e47cf036857dd", + # Note: This dataset doesn't have language filtering - all queries are English + "languages": None, # No language filtering needed + "prompt": {"query": "Find a screenshot that relevant to the user's question."}, + }, +} + + +def load_vidore_v2_data( + dataset_path: str, + revision: Optional[str] = None, + split: str = "test", + language: Optional[str] = None, +): + """ + Load ViDoRe v2 dataset. 
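+    The three dataset configs ("queries", "corpus", "qrels") are loaded
+    separately, and IDs are prefixed as "query-{split}-{query-id}" /
+    "corpus-{split}-{corpus-id}" so the three mappings line up the same way
+    MTEB builds its retrieval splits.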
+ + Returns: + corpus: dict mapping corpus_id to PIL Image + queries: dict mapping query_id to query text + qrels: dict mapping query_id to dict of {corpus_id: relevance_score} + """ + print(f"Loading dataset: {dataset_path} (split={split}, language={language})") + + # Load queries + query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) + + # Check if dataset has language field before filtering + has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names + + if language and has_language_field: + # Map language name to dataset language field value (e.g., "english" -> "eng-Latn") + dataset_language = LANGUAGE_MAPPING.get(language, language) + query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language) + # Check if filtering resulted in empty dataset + if len(query_ds_filtered) == 0: + print(f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}').") + # Try with original language value (dataset might use simple names like 'english') + print(f"Trying with original language value '{language}'...") + query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language) + if len(query_ds_filtered) == 0: + # Try to get a sample to see actual language values + try: + sample_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) + if len(sample_ds) > 0 and "language" in sample_ds.column_names: + sample_langs = set(sample_ds["language"]) + print(f"Available language values in dataset: {sample_langs}") + except Exception: + pass + else: + print(f"Found {len(query_ds_filtered)} queries using original language value '{language}'") + query_ds = query_ds_filtered + + queries = {} + for row in query_ds: + query_id = f"query-{split}-{row['query-id']}" + queries[query_id] = row["query"] + + # Load corpus (images) + corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) + + corpus = {} + for row in corpus_ds: + corpus_id = f"corpus-{split}-{row['corpus-id']}" + # Extract image from the dataset row + if "image" in row: + corpus[corpus_id] = row["image"] + elif "page_image" in row: + corpus[corpus_id] = row["page_image"] + else: + raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}") + + # Load qrels (relevance judgments) + qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) + + qrels = {} + for row in qrels_ds: + query_id = f"query-{split}-{row['query-id']}" + corpus_id = f"corpus-{split}-{row['corpus-id']}" + if query_id not in qrels: + qrels[query_id] = {} + qrels[query_id][corpus_id] = int(row["score"]) + + print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings") + + # Filter qrels to only include queries that exist + qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries} + + return corpus, queries, qrels + + +def build_index_from_corpus( + corpus: dict[str, Image.Image], + model, + processor, + index_path: str, + use_fast_plaid: bool = False, + rebuild: bool = False, +): + """ + Build index from corpus images. 
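+    corpus_ids are sorted before embedding, so the integer doc_id returned by
+    either index maps back to a corpus_id positionally via corpus_ids[doc_id].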
+
+    Returns:
+        tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
+    """
+    # Ensure consistent ordering
+    corpus_ids = sorted(corpus.keys())  # Sort for consistency
+    images = [corpus[cid] for cid in corpus_ids]
+
+    if use_fast_plaid:
+        # Check if a Fast-Plaid index exists; load it once and reuse the handle
+        if not rebuild:
+            fast_plaid_index = _load_fast_plaid_index_if_exists(index_path)
+            if fast_plaid_index is not None:
+                print(f"Fast-Plaid index already exists at {index_path}")
+                return fast_plaid_index, corpus_ids
+
+        print(f"Building Fast-Plaid index at {index_path}...")
+
+        # Embed images
+        print("Embedding images...")
+        doc_vecs = _embed_images(model, processor, images)
+
+        # Build index
+        fast_plaid_index, build_time = _build_fast_plaid_index(
+            index_path, doc_vecs, corpus_ids, images
+        )
+        print(f"Fast-Plaid index built in {build_time:.2f}s")
+        return fast_plaid_index, corpus_ids
+    else:
+        # Check if LEANN index exists
+        if not rebuild:
+            retriever = _load_retriever_if_index_exists(index_path)
+            if retriever is not None:
+                print(f"LEANN index already exists at {index_path}")
+                return retriever, corpus_ids
+
+        print(f"Building LEANN index at {index_path}...")
+
+        # Embed images
+        print("Embedding images...")
+        doc_vecs = _embed_images(model, processor, images)
+
+        # Build index
+        retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
+        print("LEANN index built")
+        return retriever, corpus_ids
+
+
+def search_queries(
+    queries: dict[str, str],
+    corpus_ids: list[str],
+    model,
+    processor,
+    index_or_retriever: Any,
+    use_fast_plaid: bool = False,
+    fast_plaid_index_path: Optional[str] = None,
+    top_k: int = 100,
+    first_stage_k: int = 500,
+    task_prompt: Optional[dict[str, str]] = None,
+) -> dict[str, dict[str, float]]:
+    """
+    Search queries against the index.
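+    When task_prompt is provided, the instruction is appended to each query as
+    "query + ' ' + instruction" before embedding, mirroring MTEB's
+    _combine_queries_with_instruction_text.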
+ + Args: + queries: dict mapping query_id to query text + corpus_ids: list of corpus_ids in the same order as the index + model: model object + processor: processor object + index_or_retriever: index or retriever object + use_fast_plaid: whether using Fast-Plaid + fast_plaid_index_path: path to Fast-Plaid index (for metadata) + top_k: top-k results to retrieve + first_stage_k: first stage k for LEANN search + task_prompt: Optional dict with prompt for query (e.g., {"query": "..."}) + + Returns: + results: dict mapping query_id to dict of {corpus_id: score} + """ + print(f"Searching {len(queries)} queries (top_k={top_k})...") + + query_ids = list(queries.keys()) + query_texts = [queries[qid] for qid in query_ids] + + # Match MTEB: combine queries with instruction/prompt if provided + # MTEB's _combine_queries_with_instruction_text does: query + " " + instruction + if task_prompt and "query" in task_prompt: + instruction = task_prompt["query"] + query_texts = [q + " " + instruction for q in query_texts] + print(f"Added task prompt to queries: {instruction}") + + # Embed queries + print("Embedding queries...") + query_vecs = _embed_queries(model, processor, query_texts) + + results = {} + + for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs): + if use_fast_plaid: + # Fast-Plaid search + search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, top_k) + # Convert doc_id back to corpus_id + query_results = {} + for score, doc_id in search_results: + if doc_id < len(corpus_ids): + corpus_id = corpus_ids[doc_id] + query_results[corpus_id] = float(score) + else: + # LEANN search + query_np = query_vec.float().numpy() + search_results = index_or_retriever.search_exact_all(query_np, topk=top_k) + # Convert doc_id back to corpus_id + query_results = {} + for score, doc_id in search_results: + if doc_id < len(corpus_ids): + corpus_id = corpus_ids[doc_id] + query_results[corpus_id] = float(score) + + results[query_id] = query_results + + return results + + +def evaluate_results( + results: dict[str, dict[str, float]], + qrels: dict[str, dict[str, int]], + k_values: list[int] = [1, 3, 5, 10, 100], +) -> dict[str, float]: + """ + Evaluate retrieval results using NDCG and other metrics. + + Returns: + Dictionary of metric scores + """ + if pytrec_eval is None: + raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval") + + # Check if we have any queries to evaluate + if len(results) == 0: + print("Warning: No queries to evaluate. 
Returning zero scores.") + # Return zero scores for all metrics + scores = {} + for k in k_values: + scores[f"ndcg_at_{k}"] = 0.0 + scores[f"map_at_{k}"] = 0.0 + scores[f"recall_at_{k}"] = 0.0 + scores[f"precision_at_{k}"] = 0.0 + scores[f"mrr_at_{k}"] = 0.0 + return scores + + print(f"Evaluating results with k_values={k_values}...") + + # Convert qrels to pytrec_eval format + qrels_pytrec = {} + for qid, rel_docs in qrels.items(): + qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()} + + # Evaluate + eval_result = calculate_retrieval_scores( + results=results, + qrels=qrels_pytrec, + k_values=k_values, + ) + + # Format scores + scores = make_score_dict( + ndcg=eval_result.ndcg, + _map=eval_result.map, + recall=eval_result.recall, + precision=eval_result.precision, + mrr=eval_result.mrr, + naucs=eval_result.naucs, + naucs_mrr=eval_result.naucs_mrr, + cv_recall=eval_result.cv_recall, + task_scores={}, + ) + + return scores + + +def evaluate_task( + task_name: str, + model_name: str, + index_path: str, + use_fast_plaid: bool = False, + fast_plaid_index_path: Optional[str] = None, + language: Optional[str] = None, + rebuild_index: bool = False, + top_k: int = 100, + first_stage_k: int = 500, + k_values: list[int] = [1, 3, 5, 10, 100], + output_dir: Optional[str] = None, +): + """ + Evaluate a single ViDoRe v2 task. + """ + print(f"\n{'='*80}") + print(f"Evaluating task: {task_name}") + print(f"{'='*80}") + + # Get task config + if task_name not in VIDORE_V2_TASKS: + raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}") + + task_config = VIDORE_V2_TASKS[task_name] + dataset_path = task_config["dataset_path"] + revision = task_config["revision"] + + # Determine language + if language is None: + # Use first language if multiple available + languages = task_config.get("languages") + if languages is None: + # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval) + language = None + elif len(languages) == 1: + language = languages[0] + else: + language = None + + # Load data + corpus, queries, qrels = load_vidore_v2_data( + dataset_path=dataset_path, + revision=revision, + split="test", + language=language, + ) + + # Check if we have any queries + if len(queries) == 0: + print(f"\nWarning: No queries found for task {task_name} with language {language}. 
Skipping evaluation.") + # Return zero scores + scores = {} + for k in k_values: + scores[f"ndcg_at_{k}"] = 0.0 + scores[f"map_at_{k}"] = 0.0 + scores[f"recall_at_{k}"] = 0.0 + scores[f"precision_at_{k}"] = 0.0 + scores[f"mrr_at_{k}"] = 0.0 + return scores + + # Load model + print(f"\nLoading model: {model_name}") + model_name_actual, model, processor, device_str, device, dtype = _load_colvision(model_name) + print(f"Model loaded: {model_name_actual}") + + # Build or load index + index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path + if index_path_full is None: + index_path_full = f"./indexes/{task_name}_{model_name}" + if use_fast_plaid: + index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid" + + index_or_retriever, corpus_ids_ordered = build_index_from_corpus( + corpus=corpus, + model=model, + processor=processor, + index_path=index_path_full, + use_fast_plaid=use_fast_plaid, + rebuild=rebuild_index, + ) + + # Search queries + task_prompt = task_config.get("prompt") + results = search_queries( + queries=queries, + corpus_ids=corpus_ids_ordered, + model=model, + processor=processor, + index_or_retriever=index_or_retriever, + use_fast_plaid=use_fast_plaid, + fast_plaid_index_path=fast_plaid_index_path, + top_k=top_k, + first_stage_k=first_stage_k, + task_prompt=task_prompt, + ) + + # Evaluate + scores = evaluate_results(results, qrels, k_values=k_values) + + # Print results + print(f"\n{'='*80}") + print(f"Results for {task_name}:") + print(f"{'='*80}") + for metric, value in scores.items(): + if isinstance(value, (int, float)): + print(f" {metric}: {value:.5f}") + + # Save results + if output_dir: + os.makedirs(output_dir, exist_ok=True) + results_file = os.path.join(output_dir, f"{task_name}_results.json") + scores_file = os.path.join(output_dir, f"{task_name}_scores.json") + + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + print(f"\nSaved results to: {results_file}") + + with open(scores_file, "w") as f: + json.dump(scores, f, indent=2) + print(f"Saved scores to: {scores_file}") + + return scores + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing" + ) + parser.add_argument( + "--model", + type=str, + default="colqwen2", + choices=["colqwen2", "colpali"], + help="Model to use", + ) + parser.add_argument( + "--task", + type=str, + default=None, + help="Specific task to evaluate (or 'all' for all tasks)", + ) + parser.add_argument( + "--tasks", + type=str, + default="all", + help="Tasks to evaluate: 'all' or comma-separated list", + ) + parser.add_argument( + "--index-path", + type=str, + default=None, + help="Path to LEANN index (auto-generated if not provided)", + ) + parser.add_argument( + "--use-fast-plaid", + action="store_true", + help="Use Fast-Plaid instead of LEANN", + ) + parser.add_argument( + "--fast-plaid-index-path", + type=str, + default=None, + help="Path to Fast-Plaid index (auto-generated if not provided)", + ) + parser.add_argument( + "--rebuild-index", + action="store_true", + help="Rebuild index even if it exists", + ) + parser.add_argument( + "--language", + type=str, + default=None, + help="Language to evaluate (default: first available)", + ) + parser.add_argument( + "--top-k", + type=int, + default=100, + help="Top-k results to retrieve", + ) + parser.add_argument( + "--first-stage-k", + type=int, + default=500, + help="First stage k for LEANN search", + ) + parser.add_argument( + "--k-values", + type=str, + 
default="1,3,5,10,100", + help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./vidore_v2_results", + help="Output directory for results", + ) + + args = parser.parse_args() + + # Parse k_values + k_values = [int(k.strip()) for k in args.k_values.split(",")] + + # Determine tasks to evaluate + if args.task: + tasks_to_eval = [args.task] + elif args.tasks.lower() == "all": + tasks_to_eval = list(VIDORE_V2_TASKS.keys()) + else: + tasks_to_eval = [t.strip() for t in args.tasks.split(",")] + + print(f"Tasks to evaluate: {tasks_to_eval}") + + # Evaluate each task + all_scores = {} + for task_name in tasks_to_eval: + try: + scores = evaluate_task( + task_name=task_name, + model_name=args.model, + index_path=args.index_path, + use_fast_plaid=args.use_fast_plaid, + fast_plaid_index_path=args.fast_plaid_index_path, + language=args.language, + rebuild_index=args.rebuild_index, + top_k=args.top_k, + first_stage_k=args.first_stage_k, + k_values=k_values, + output_dir=args.output_dir, + ) + all_scores[task_name] = scores + except Exception as e: + print(f"\nError evaluating {task_name}: {e}") + import traceback + traceback.print_exc() + continue + + # Print summary + if all_scores: + print(f"\n{'='*80}") + print("SUMMARY") + print(f"{'='*80}") + for task_name, scores in all_scores.items(): + print(f"\n{task_name}:") + # Print main metrics + for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]: + if metric in scores: + print(f" {metric}: {scores[metric]:.5f}") + + +if __name__ == "__main__": + main() +
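
For reference: the ndcg_at_k values reported by evaluate_results come from
pytrec_eval's ndcg_cut measure. For binary qrels, as used by these ViDoRe v2
tasks, it reduces to the hand computation below. This is a minimal standalone
sketch; the function name and toy data are illustrative and not part of the
benchmark script.

import math

def ndcg_at_k(run: dict[str, float], qrels: dict[str, int], k: int) -> float:
    """NDCG@k for one query, linear gains (identical to exponential gains
    when relevance is binary)."""
    # Rank retrieved docs by descending score and keep the top k.
    ranked = sorted(run, key=run.get, reverse=True)[:k]
    # DCG: relevance discounted by log2 of (1-based rank + 1).
    dcg = sum(qrels.get(doc, 0) / math.log2(rank + 2) for rank, doc in enumerate(ranked))
    # IDCG: the best achievable DCG given the relevance judgments.
    ideal = sorted(qrels.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0

# One relevant page retrieved at rank 2: DCG = 1/log2(3), IDCG = 1/log2(2).
print(ndcg_at_k({"corpus-test-1": 0.9, "corpus-test-2": 0.8}, {"corpus-test-2": 1}, k=3))
# -> 0.6309...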