From d599566fd710d3cf9170ae27e497360bc42518ab Mon Sep 17 00:00:00 2001 From: Yichuan Wang <73766326+yichuan-w@users.noreply.github.com> Date: Wed, 3 Dec 2025 01:09:39 -0800 Subject: [PATCH] =?UTF-8?q?Revert=20"[Multi-vector]Add=20timing=20instrume?= =?UTF-8?q?ntation=20and=20multi-dataset=20support=20fo=E2=80=A6"=20(#180)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 00770aebbb5af5d1952a44bbf856018fb8805bc0. --- .gitignore | 3 +- .../colqwen_forward.py | 132 ---- .../leann_multi_vector.py | 598 +----------------- .../multi-vector-leann-similarity-map.py | 537 ++-------------- .../vidore_v1_benchmark.py | 399 ------------ .../vidore_v2_benchmark.py | 439 ------------- 6 files changed, 60 insertions(+), 2048 deletions(-) delete mode 100755 apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py delete mode 100644 apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py delete mode 100644 apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py diff --git a/.gitignore b/.gitignore index e60379d..19df865 100755 --- a/.gitignore +++ b/.gitignore @@ -91,8 +91,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/ *.meta.json *.passages.json -*.npy -*.db + batchtest.py tests/__pytest_cache__/ tests/__pycache__/ diff --git a/apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py b/apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py deleted file mode 100755 index 53006d6..0000000 --- a/apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python3 -"""Simple test script to test colqwen2 forward pass with a single image.""" - -import os -import sys -from pathlib import Path - -# Add the current directory to path to import leann_multi_vector -sys.path.insert(0, str(Path(__file__).parent)) - -from PIL import Image -import torch - -from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable - -# Ensure repo paths are importable -_ensure_repo_paths_importable(__file__) - -# Set environment variable -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - -def create_test_image(): - """Create a simple test image.""" - # Create a simple RGB image (800x600) - img = Image.new('RGB', (800, 600), color='white') - return img - - -def load_test_image_from_file(): - """Try to load an image from the indexes directory if available.""" - # Try to find an existing image in the indexes directory - indexes_dir = Path(__file__).parent / "indexes" - - # Look for images in common locations - possible_paths = [ - indexes_dir / "vidore_fastplaid" / "images", - indexes_dir / "colvision_large.leann.images", - indexes_dir / "colvision.leann.images", - ] - - for img_dir in possible_paths: - if img_dir.exists(): - # Find first image file - for ext in ['.png', '.jpg', '.jpeg']: - for img_file in img_dir.glob(f'*{ext}'): - print(f"Loading test image from: {img_file}") - return Image.open(img_file) - - return None - - -def main(): - print("=" * 60) - print("Testing ColQwen2 Forward Pass") - print("=" * 60) - - # Step 1: Load or create test image - print("\n[Step 1] Loading test image...") - test_image = load_test_image_from_file() - if test_image is None: - print("No existing image found, creating a simple test image...") - test_image = create_test_image() - else: - print(f"✓ Loaded image: {test_image.size} ({test_image.mode})") - - # Convert to RGB if needed - if test_image.mode != 'RGB': - test_image = 
test_image.convert('RGB') - print(f"✓ Converted to RGB: {test_image.size}") - - # Step 2: Load model - print("\n[Step 2] Loading ColQwen2 model...") - try: - model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2") - print(f"✓ Model loaded: {model_name}") - print(f"✓ Device: {device_str}, dtype: {dtype}") - - # Print model info - if hasattr(model, 'device'): - print(f"✓ Model device: {model.device}") - if hasattr(model, 'dtype'): - print(f"✓ Model dtype: {model.dtype}") - - except Exception as e: - print(f"✗ Error loading model: {e}") - import traceback - traceback.print_exc() - return - - # Step 3: Test forward pass - print("\n[Step 3] Running forward pass...") - try: - # Use the _embed_images function which handles batching and forward pass - images = [test_image] - print(f"Processing {len(images)} image(s)...") - - doc_vecs = _embed_images(model, processor, images) - - print(f"✓ Forward pass completed!") - print(f"✓ Number of embeddings: {len(doc_vecs)}") - - if len(doc_vecs) > 0: - emb = doc_vecs[0] - print(f"✓ Embedding shape: {emb.shape}") - print(f"✓ Embedding dtype: {emb.dtype}") - print(f"✓ Embedding stats:") - print(f" - Min: {emb.min().item():.4f}") - print(f" - Max: {emb.max().item():.4f}") - print(f" - Mean: {emb.mean().item():.4f}") - print(f" - Std: {emb.std().item():.4f}") - - # Check for NaN or Inf - if torch.isnan(emb).any(): - print("⚠ Warning: Embedding contains NaN values!") - if torch.isinf(emb).any(): - print("⚠ Warning: Embedding contains Inf values!") - - except Exception as e: - print(f"✗ Error during forward pass: {e}") - import traceback - traceback.print_exc() - return - - print("\n" + "=" * 60) - print("Test completed successfully!") - print("=" * 60) - - -if __name__ == "__main__": - main() - diff --git a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py index 2ea933f..8353d3a 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py @@ -3,7 +3,6 @@ import json import os import re import sys -import time from pathlib import Path from typing import Any, Optional, cast @@ -195,7 +194,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]: dataloader = DataLoader( dataset=ListDataset[Image.Image](images), - batch_size=32, + batch_size=1, shuffle=False, collate_fn=lambda x: processor.process_images(x), ) @@ -219,47 +218,32 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]: def _embed_queries(model, processor, queries: list[str]) -> list[Any]: import torch + from colpali_engine.utils.torch_utils import ListDataset + from torch.utils.data import DataLoader model.eval() - # Match MTEB's exact query processing from ColPaliEngineWrapper.get_text_embeddings: - # 1. MTEB receives batch["text"] which already includes instruction/prompt (from _combine_queries_with_instruction_text) - # 2. Manually adds: query_prefix + text + query_augmentation_token * 10 - # 3. Calls processor.process_queries(batch) where batch is now a list of strings - # 4. 
process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10) - # - # This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total - # We need to match this exactly to reproduce MTEB results - - all_embeds = [] - batch_size = 32 # Match MTEB's default batch_size - - with torch.no_grad(): - for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"): - batch_queries = queries[i : i + batch_size] - - # Match MTEB: manually add query_prefix + text + query_augmentation_token * 10 - # Then process_queries will add them again (resulting in 20 augmentation tokens total) - batch = [ - processor.query_prefix + t + processor.query_augmentation_token * 10 - for t in batch_queries - ] - inputs = processor.process_queries(batch) - inputs = {k: v.to(model.device) for k, v in inputs.items()} + dataloader = DataLoader( + dataset=ListDataset[str](queries), + batch_size=1, + shuffle=False, + collate_fn=lambda x: processor.process_queries(x), + ) + q_vecs: list[Any] = [] + for batch_query in tqdm(dataloader, desc="Embedding queries"): + with torch.no_grad(): + batch_query = {k: v.to(model.device) for k, v in batch_query.items()} if model.device.type == "cuda": with torch.autocast( device_type="cuda", dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16, ): - outs = model(**inputs) + embeddings_query = model(**batch_query) else: - outs = model(**inputs) - - # Match MTEB: convert to float32 on CPU - all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32)))) - - return all_embeds + embeddings_query = model(**batch_query) + q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu")))) + return q_vecs def _build_index( @@ -299,279 +283,6 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]: return None -def _build_fast_plaid_index( - index_path: str, - doc_vecs: list[Any], - filepaths: list[str], - images: list[Image.Image], -) -> tuple[Any, float]: - """ - Build a Fast-Plaid index from document embeddings. 
- - Args: - index_path: Path to save the Fast-Plaid index - doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim]) - filepaths: List of filepath identifiers for each document - images: List of PIL Images corresponding to each document - - Returns: - Tuple of (FastPlaid index object, build_time_in_seconds) - """ - import torch - from fast_plaid import search as fast_plaid_search - - print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...") - _t0 = time.perf_counter() - - # Convert doc_vecs to list of tensors - documents_embeddings = [] - for i, vec in enumerate(doc_vecs): - if i % 1000 == 0: - print(f" Converting embedding {i}/{len(doc_vecs)}...") - if not isinstance(vec, torch.Tensor): - vec = ( - torch.tensor(vec) - if isinstance(vec, np.ndarray) - else torch.from_numpy(np.array(vec)) - ) - # Ensure float32 for Fast-Plaid - if vec.dtype != torch.float32: - vec = vec.float() - documents_embeddings.append(vec) - - print(f" Converted {len(documents_embeddings)} embeddings") - if len(documents_embeddings) > 0: - print(f" First embedding shape: {documents_embeddings[0].shape}") - print(f" First embedding dtype: {documents_embeddings[0].dtype}") - - # Prepare metadata for Fast-Plaid - print(f" Preparing metadata for {len(filepaths)} documents...") - metadata_list = [] - for i, filepath in enumerate(filepaths): - metadata_list.append( - { - "filepath": filepath, - "index": i, - } - ) - - # Create Fast-Plaid index - print(f" Creating FastPlaid object with index path: {index_path}") - try: - fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path) - print(" FastPlaid object created successfully") - except Exception as e: - print(f" Error creating FastPlaid object: {type(e).__name__}: {e}") - import traceback - - traceback.print_exc() - raise - - print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...") - try: - fast_plaid_index.create( - documents_embeddings=documents_embeddings, - metadata=metadata_list, - ) - print(" Fast-Plaid index created successfully") - except Exception as e: - print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}") - import traceback - - traceback.print_exc() - raise - - build_secs = time.perf_counter() - _t0 - - # Save images separately (Fast-Plaid doesn't store images) - print(f" Saving {len(images)} images...") - images_dir = Path(index_path) / "images" - images_dir.mkdir(parents=True, exist_ok=True) - for i, img in enumerate(tqdm(images, desc="Saving images")): - img_path = images_dir / f"doc_{i}.png" - img.save(str(img_path)) - - return fast_plaid_index, build_secs - - -def _fast_plaid_index_exists(index_path: str) -> bool: - """ - Check if Fast-Plaid index exists by checking for key files. - This avoids creating the FastPlaid object which may trigger memory allocation. 
- - Args: - index_path: Path to the Fast-Plaid index - - Returns: - True if index appears to exist, False otherwise - """ - index_path_obj = Path(index_path) - if not index_path_obj.exists() or not index_path_obj.is_dir(): - return False - - # Fast-Plaid creates a SQLite database file for metadata - # Check for metadata.db as the most reliable indicator - metadata_db = index_path_obj / "metadata.db" - if metadata_db.exists() and metadata_db.stat().st_size > 0: - return True - - # Also check if directory has any files (might be incomplete index) - try: - if any(index_path_obj.iterdir()): - return True - except Exception: - pass - - return False - - -def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]: - """ - Load Fast-Plaid index if it exists. - First checks if index files exist, then creates the FastPlaid object. - The actual index data loading happens lazily when search is called. - - Args: - index_path: Path to the Fast-Plaid index - - Returns: - FastPlaid index object if exists, None otherwise - """ - try: - from fast_plaid import search as fast_plaid_search - - # First check if index files exist without creating the object - if not _fast_plaid_index_exists(index_path): - return None - - # Now try to create FastPlaid object - # This may trigger some memory allocation, but the full index loading is deferred - fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path) - return fast_plaid_index - except ImportError: - # fast-plaid not installed - return None - except Exception as e: - # Any error (including memory errors from Rust backend) - return None - # The error will be caught and index will be rebuilt - print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}") - return None - - -def _search_fast_plaid( - fast_plaid_index: Any, - query_vec: Any, - top_k: int, -) -> tuple[list[tuple[float, int]], float]: - """ - Search Fast-Plaid index with a query embedding. - - Args: - fast_plaid_index: FastPlaid index object - query_vec: Query embedding tensor with shape [num_tokens, embedding_dim] - top_k: Number of top results to return - - Returns: - Tuple of (results_list, search_time_in_seconds) - results_list: List of (score, doc_id) tuples - """ - import torch - - _t0 = time.perf_counter() - - # Ensure query is a torch tensor - if not isinstance(query_vec, torch.Tensor): - q_vec_tensor = ( - torch.tensor(query_vec) - if isinstance(query_vec, np.ndarray) - else torch.from_numpy(np.array(query_vec)) - ) - else: - q_vec_tensor = query_vec - - # Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim] - if q_vec_tensor.dim() == 2: - q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim] - - # Perform search - scores = fast_plaid_index.search( - queries_embeddings=q_vec_tensor, - top_k=top_k, - show_progress=True, - ) - - search_secs = time.perf_counter() - _t0 - - # Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples - results = [] - if scores and len(scores) > 0: - query_results = scores[0] - # Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format - results = [(float(score), int(doc_id)) for doc_id, score in query_results] - - return results, search_secs - - -def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]: - """ - Retrieve image for a document from Fast-Plaid index. 
- - Args: - index_path: Path to the Fast-Plaid index - doc_id: Document ID returned by Fast-Plaid search - - Returns: - PIL Image if found, None otherwise - - Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid - doc_id may differ from the file naming index. - """ - # First get metadata to find the actual index used for file naming - metadata = _get_fast_plaid_metadata(index_path, doc_id) - if metadata is None: - # Fallback: try using doc_id directly - file_index = doc_id - else: - # Use the 'index' field from metadata, which matches the file naming - file_index = metadata.get("index", doc_id) - - images_dir = Path(index_path) / "images" - image_path = images_dir / f"doc_{file_index}.png" - - if image_path.exists(): - return Image.open(image_path) - - # If not found with index, try doc_id as fallback - if file_index != doc_id: - fallback_path = images_dir / f"doc_{doc_id}.png" - if fallback_path.exists(): - return Image.open(fallback_path) - - return None - - -def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]: - """ - Retrieve metadata for a document from Fast-Plaid index. - - Args: - index_path: Path to the Fast-Plaid index - doc_id: Document ID - - Returns: - Dictionary with metadata if found, None otherwise - """ - try: - from fast_plaid import filtering - - metadata_list = filtering.get(index=index_path, subset=[doc_id]) - if metadata_list and len(metadata_list) > 0: - return metadata_list[0] - except Exception: - pass - return None - - def _generate_similarity_map( model, processor, @@ -967,15 +678,11 @@ class LeannMultiVector: return (float(score), doc_id) scores: list[tuple[float, int]] = [] - # load and core time - start_time = time.time() with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids] for fut in concurrent.futures.as_completed(futures): scores.append(fut.result()) - end_time = time.time() - print(f"Number of candidate doc ids: {len(candidate_doc_ids)}") - print(f"Time taken in load and core time: {end_time - start_time} seconds") + scores.sort(key=lambda x: x[0], reverse=True) return scores[:topk] if len(scores) >= topk else scores @@ -1003,6 +710,7 @@ class LeannMultiVector: emb_path = self._embeddings_path() if not emb_path.exists(): return self.search(data, topk) + all_embeddings = np.load(emb_path, mmap_mode="r") if all_embeddings.dtype != np.float32: all_embeddings = all_embeddings.astype(np.float32) @@ -1010,29 +718,23 @@ class LeannMultiVector: assert self._docid_to_indices is not None candidate_doc_ids = list(self._docid_to_indices.keys()) - def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]: + def _score_one(doc_id: int) -> tuple[float, int]: token_indices = self._docid_to_indices.get(doc_id, []) if not token_indices: return (0.0, doc_id) - doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32) + doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32) sim = np.dot(data, doc_vecs.T) sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30) score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum() return (float(score), doc_id) scores: list[tuple[float, int]] = [] - # load and core time - start_time = time.time() with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: futures = [ex.submit(_score_one, d) for d in candidate_doc_ids] for fut in concurrent.futures.as_completed(futures): scores.append(fut.result()) - end_time 
= time.time() - # print number of candidate doc ids - print(f"Number of candidate doc ids: {len(candidate_doc_ids)}") - print(f"Time taken in load and core time: {end_time - start_time} seconds") + scores.sort(key=lambda x: x[0], reverse=True) - del all_embeddings return scores[:topk] if len(scores) >= topk else scores def get_image(self, doc_id: int) -> Optional[Image.Image]: @@ -1076,259 +778,3 @@ class LeannMultiVector: "image_path": meta.get("image_path", ""), } return None - - -class ViDoReBenchmarkEvaluator: - """ - A reusable class for evaluating ViDoRe benchmarks (v1 and v2). - This class encapsulates common functionality for building indexes, searching, and evaluating. - """ - - def __init__( - self, - model_name: str, - use_fast_plaid: bool = False, - top_k: int = 100, - first_stage_k: int = 500, - k_values: Optional[list[int]] = None, - ): - """ - Initialize the evaluator. - - Args: - model_name: Model name ("colqwen2" or "colpali") - use_fast_plaid: Whether to use Fast-Plaid instead of LEANN - top_k: Top-k results to retrieve - first_stage_k: First stage k for LEANN search - k_values: List of k values for evaluation metrics - """ - self.model_name = model_name - self.use_fast_plaid = use_fast_plaid - self.top_k = top_k - self.first_stage_k = first_stage_k - self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100] - - # Load model once (can be reused across tasks) - self._model = None - self._processor = None - self._model_name_actual = None - - def _load_model_if_needed(self): - """Lazy load the model.""" - if self._model is None: - print(f"\nLoading model: {self.model_name}") - self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision( - self.model_name - ) - print(f"Model loaded: {self._model_name_actual}") - - def build_index_from_corpus( - self, - corpus: dict[str, Image.Image], - index_path: str, - rebuild: bool = False, - ) -> tuple[Any, list[str]]: - """ - Build index from corpus images. 
- - Args: - corpus: dict mapping corpus_id to PIL Image - index_path: Path to save/load the index - rebuild: Whether to rebuild even if index exists - - Returns: - tuple: (retriever or fast_plaid_index object, list of corpus_ids in order) - """ - self._load_model_if_needed() - - # Ensure consistent ordering - corpus_ids = sorted(corpus.keys()) - images = [corpus[cid] for cid in corpus_ids] - - if self.use_fast_plaid: - # Check if Fast-Plaid index exists - if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None: - print(f"Fast-Plaid index already exists at {index_path}") - return _load_fast_plaid_index_if_exists(index_path), corpus_ids - - print(f"Building Fast-Plaid index at {index_path}...") - print("Embedding images...") - doc_vecs = _embed_images(self._model, self._processor, images) - - fast_plaid_index, build_time = _build_fast_plaid_index( - index_path, doc_vecs, corpus_ids, images - ) - print(f"Fast-Plaid index built in {build_time:.2f}s") - return fast_plaid_index, corpus_ids - else: - # Check if LEANN index exists - if not rebuild: - retriever = _load_retriever_if_index_exists(index_path) - if retriever is not None: - print(f"LEANN index already exists at {index_path}") - return retriever, corpus_ids - - print(f"Building LEANN index at {index_path}...") - print("Embedding images...") - doc_vecs = _embed_images(self._model, self._processor, images) - - retriever = _build_index(index_path, doc_vecs, corpus_ids, images) - print("LEANN index built") - return retriever, corpus_ids - - def search_queries( - self, - queries: dict[str, str], - corpus_ids: list[str], - index_or_retriever: Any, - fast_plaid_index_path: Optional[str] = None, - task_prompt: Optional[dict[str, str]] = None, - ) -> dict[str, dict[str, float]]: - """ - Search queries against the index. 
- - Args: - queries: dict mapping query_id to query text - corpus_ids: list of corpus_ids in the same order as the index - index_or_retriever: index or retriever object - fast_plaid_index_path: path to Fast-Plaid index (for metadata) - task_prompt: Optional dict with prompt for query (e.g., {"query": "..."}) - - Returns: - results: dict mapping query_id to dict of {corpus_id: score} - """ - self._load_model_if_needed() - - print(f"Searching {len(queries)} queries (top_k={self.top_k})...") - - query_ids = list(queries.keys()) - query_texts = [queries[qid] for qid in query_ids] - - # Note: ColPaliEngineWrapper does NOT use task prompt from metadata - # It uses query_prefix + text + query_augmentation_token (handled in _embed_queries) - # So we don't append task_prompt here to match MTEB behavior - - # Embed queries - print("Embedding queries...") - query_vecs = _embed_queries(self._model, self._processor, query_texts) - - results = {} - - for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs): - if self.use_fast_plaid: - # Fast-Plaid search - search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k) - query_results = {} - for score, doc_id in search_results: - if doc_id < len(corpus_ids): - corpus_id = corpus_ids[doc_id] - query_results[corpus_id] = float(score) - else: - # LEANN search - import torch - - query_np = ( - query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec - ) - search_results = index_or_retriever.search_exact(query_np, topk=self.top_k) - query_results = {} - for score, doc_id in search_results: - if doc_id < len(corpus_ids): - corpus_id = corpus_ids[doc_id] - query_results[corpus_id] = float(score) - - results[query_id] = query_results - - return results - - @staticmethod - def evaluate_results( - results: dict[str, dict[str, float]], - qrels: dict[str, dict[str, int]], - k_values: Optional[list[int]] = None, - ) -> dict[str, float]: - """ - Evaluate retrieval results using NDCG and other metrics. - - Args: - results: dict mapping query_id to dict of {corpus_id: score} - qrels: dict mapping query_id to dict of {corpus_id: relevance_score} - k_values: List of k values for evaluation metrics - - Returns: - Dictionary of metric scores - """ - try: - from mteb._evaluators.retrieval_metrics import ( - calculate_retrieval_scores, - make_score_dict, - ) - except ImportError: - raise ImportError( - "pytrec_eval is required for evaluation. Install with: pip install pytrec-eval" - ) - - if k_values is None: - k_values = [1, 3, 5, 10, 100] - - # Check if we have any queries to evaluate - if len(results) == 0: - print("Warning: No queries to evaluate. 
Returning zero scores.") - scores = {} - for k in k_values: - scores[f"ndcg_at_{k}"] = 0.0 - scores[f"map_at_{k}"] = 0.0 - scores[f"recall_at_{k}"] = 0.0 - scores[f"precision_at_{k}"] = 0.0 - scores[f"mrr_at_{k}"] = 0.0 - return scores - - print(f"Evaluating results with k_values={k_values}...") - print(f"Before filtering: {len(results)} results, {len(qrels)} qrels") - - # Filter to ensure qrels and results have the same query set - # This matches MTEB behavior: only evaluate queries that exist in both - # pytrec_eval only evaluates queries in qrels, so we need to ensure - # results contains all queries in qrels, and filter out queries not in qrels - results_filtered = {qid: res for qid, res in results.items() if qid in qrels} - qrels_filtered = { - qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered - } - - print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels") - - if len(results_filtered) != len(qrels_filtered): - print( - f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries" - ) - missing_in_results = set(qrels.keys()) - set(results.keys()) - if missing_in_results: - print(f"Queries in qrels but not in results: {len(missing_in_results)} queries") - print(f"First 5 missing queries: {list(missing_in_results)[:5]}") - - # Convert qrels to pytrec_eval format - qrels_pytrec = {} - for qid, rel_docs in qrels_filtered.items(): - qrels_pytrec[qid] = dict(rel_docs.items()) - - # Evaluate - eval_result = calculate_retrieval_scores( - results=results_filtered, - qrels=qrels_pytrec, - k_values=k_values, - ) - - # Format scores - scores = make_score_dict( - ndcg=eval_result.ndcg, - _map=eval_result.map, - recall=eval_result.recall, - precision=eval_result.precision, - mrr=eval_result.mrr, - naucs=eval_result.naucs, - naucs_mrr=eval_result.naucs_mrr, - cv_recall=eval_result.cv_recall, - task_scores={}, - ) - - return scores diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py index 42b4f00..c4c01e8 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py @@ -1,19 +1,12 @@ ## Jupyter-style notebook script # %% # uv pip install matplotlib qwen_vl_utils -import argparse -import faulthandler import os -import time from typing import Any, Optional -import numpy as np from PIL import Image from tqdm import tqdm -# Enable faulthandler to get stack trace on segfault -faulthandler.enable() - from leann_multi_vector import ( # utility functions/classes _ensure_repo_paths_importable, @@ -25,11 +18,6 @@ from leann_multi_vector import ( # utility functions/classes _build_index, _load_retriever_if_index_exists, _generate_similarity_map, - _build_fast_plaid_index, - _load_fast_plaid_index_if_exists, - _search_fast_plaid, - _get_fast_plaid_image, - _get_fast_plaid_metadata, QwenVL, ) @@ -43,33 +31,8 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2" # Data source: set to True to use the Hugging Face dataset example (recommended) USE_HF_DATASET: bool = True -# Single dataset name (used when DATASET_NAMES is None) DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector" -# Multiple datasets to combine (if provided, DATASET_NAME is ignored) -# Can be: -# - List of strings: ["dataset1", "dataset2"] -# - List of tuples: [("dataset1", "config1"), ("dataset2", 
None)] # None = no config needed -# - Mixed: ["dataset1", ("dataset2", "config2")] -# -# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment): -# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field) -# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config) -# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images) -# - "pixparse/arxiv-papers" (if available, arXiv papers) -# - "allenai/ai2d" (AI2D diagram dataset, has "image" field) -# - "huggingface/document-images" (if available) -# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified -# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None -DATASET_NAMES = [ - "weaviate/arXiv-AI-papers-multi-vector", - ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs -] -# Load multiple splits to get more data (e.g., ["train", "test", "validation"]) -# Set to None to try loading all available splits automatically -DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits -# Image field name in the dataset (auto-detect if None) -# Common names: "page_image", "image", "images", "img" -IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect +DATASET_SPLIT: str = "train" MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all # Local pages (used when USE_HF_DATASET == False) @@ -77,13 +40,10 @@ PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf" PAGES_DIR: str = "./pages" # Index + retrieval settings -# Use a different index path for larger dataset to avoid overwriting existing index -INDEX_PATH: str = "./indexes/colvision_large.leann" -# Fast-Plaid index settings (alternative to LEANN index) -# These are now command-line arguments (see CLI overrides section) +INDEX_PATH: str = "./indexes/colvision.leann" TOPK: int = 3 FIRST_STAGE_K: int = 500 -REBUILD_INDEX: bool = True +REBUILD_INDEX: bool = False # Artifacts SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png" @@ -94,310 +54,38 @@ ANSWER: bool = True MAX_NEW_TOKENS: int = 1024 -# %% -# CLI overrides -parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo") -parser.add_argument( - "--search-method", - type=str, - choices=["ann", "exact", "exact-all"], - default="ann", - help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).", -) -parser.add_argument( - "--query", - type=str, - default=QUERY, - help=f"Query string to search for. Default: '{QUERY}'", -) -parser.add_argument( - "--use-fast-plaid", - action="store_true", - default=False, - help="Set to True to use fast-plaid instead of LEANN. Default: False", -) -parser.add_argument( - "--fast-plaid-index-path", - type=str, - default="./indexes/colvision_fastplaid", - help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'", -) -parser.add_argument( - "--topk", - type=int, - default=TOPK, - help=f"Number of top results to retrieve. 
Default: {TOPK}", -) -cli_args, _unknown = parser.parse_known_args() -SEARCH_METHOD: str = cli_args.search_method -QUERY = cli_args.query # Override QUERY with CLI argument if provided -USE_FAST_PLAID: bool = cli_args.use_fast_plaid -FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path -TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided - # %% # Step 1: Check if we can skip data loading (index already exists) retriever: Optional[Any] = None -fast_plaid_index: Optional[Any] = None need_to_build_index = REBUILD_INDEX -if USE_FAST_PLAID: - # Fast-Plaid index handling - if not REBUILD_INDEX: - try: - fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH) - if fast_plaid_index is not None: - print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}") - need_to_build_index = False - else: - print(f"Fast-Plaid index not found, will build new index") - need_to_build_index = True - except Exception as e: - # If loading fails (e.g., memory error, corrupted index), rebuild - print(f"Warning: Failed to load Fast-Plaid index: {e}") - print("Will rebuild the index...") - need_to_build_index = True - fast_plaid_index = None +if not REBUILD_INDEX: + retriever = _load_retriever_if_index_exists(INDEX_PATH) + if retriever is not None: + print(f"✓ Index loaded from {INDEX_PATH}") + print(f"✓ Images available at: {retriever._images_dir_path()}") + need_to_build_index = False else: - print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index") - need_to_build_index = True -else: - # Original LEANN index handling - if not REBUILD_INDEX: - retriever = _load_retriever_if_index_exists(INDEX_PATH) - if retriever is not None: - print(f"✓ Index loaded from {INDEX_PATH}") - print(f"✓ Images available at: {retriever._images_dir_path()}") - need_to_build_index = False - else: - print(f"Index not found, will build new index") - need_to_build_index = True - else: - print(f"REBUILD_INDEX=True, will rebuild index") + print(f"Index not found, will build new index") need_to_build_index = True # Step 2: Load data only if we need to build the index if need_to_build_index: print("Loading dataset...") if USE_HF_DATASET: - from datasets import load_dataset, concatenate_datasets, DatasetDict + from datasets import load_dataset - # Determine which datasets to load - if DATASET_NAMES is not None: - dataset_names_to_load = DATASET_NAMES - print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}") - else: - dataset_names_to_load = [DATASET_NAME] - print(f"Loading single dataset: {DATASET_NAME}") - - # Load and combine datasets - all_datasets_to_concat = [] - - for dataset_entry in dataset_names_to_load: - # Handle both string and tuple formats - if isinstance(dataset_entry, tuple): - dataset_name, config_name = dataset_entry - else: - dataset_name = dataset_entry - config_name = None - - print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else "")) - - # Load dataset to check available splits - # If config_name is provided, use it; otherwise try without config - try: - if config_name: - dataset_dict = load_dataset(dataset_name, config_name) - else: - dataset_dict = load_dataset(dataset_name) - except ValueError as e: - if "Config name is missing" in str(e): - # Try to get available configs and suggest - from datasets import get_dataset_config_names - try: - available_configs = get_dataset_config_names(dataset_name) - raise ValueError( - f"Dataset '{dataset_name}' requires a config name. 
" - f"Available configs: {available_configs}. " - f"Please specify as: ('{dataset_name}', 'config_name')" - ) from e - except Exception: - raise ValueError( - f"Dataset '{dataset_name}' requires a config name. " - f"Please specify as: ('{dataset_name}', 'config_name')" - ) from e - raise - - # Determine which splits to load - if DATASET_SPLITS is None: - # Auto-detect: try to load all available splits - available_splits = list(dataset_dict.keys()) - print(f" Auto-detected splits: {available_splits}") - splits_to_load = available_splits - else: - splits_to_load = DATASET_SPLITS - - # Load and concatenate multiple splits for this dataset - datasets_to_concat = [] - for split in splits_to_load: - if split not in dataset_dict: - print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}") - continue - split_dataset = dataset_dict[split] - print(f" Loaded split '{split}': {len(split_dataset)} pages") - datasets_to_concat.append(split_dataset) - - if not datasets_to_concat: - print(f" Warning: No valid splits found for {dataset_name}. Skipping.") - continue - - # Concatenate splits for this dataset - if len(datasets_to_concat) > 1: - combined_dataset = concatenate_datasets(datasets_to_concat) - print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages") - else: - combined_dataset = datasets_to_concat[0] - - all_datasets_to_concat.append(combined_dataset) - - if not all_datasets_to_concat: - raise RuntimeError("No valid datasets or splits found.") - - # Concatenate all datasets - if len(all_datasets_to_concat) > 1: - dataset = concatenate_datasets(all_datasets_to_concat) - print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages") - else: - dataset = all_datasets_to_concat[0] - - # Apply MAX_DOCS limit if specified + dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT) N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset)) - if N < len(dataset): - print(f"Limiting to {N} pages (from {len(dataset)} total)") - dataset = dataset.select(range(N)) - - # Auto-detect image field name if not specified - if IMAGE_FIELD_NAME is None: - # Check multiple samples to find the most common image field - # (useful when datasets are merged and may have different field names) - possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"] - field_counts = {} - - # Check first few samples to find image fields - num_samples_to_check = min(10, len(dataset)) - for sample_idx in range(num_samples_to_check): - sample = dataset[sample_idx] - for field in possible_image_fields: - if field in sample and sample[field] is not None: - value = sample[field] - if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')): - field_counts[field] = field_counts.get(field, 0) + 1 - - # Choose the most common field, or first found if tied - if field_counts: - image_field = max(field_counts.items(), key=lambda x: x[1])[0] - print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)") - else: - # Fallback: check first sample only - sample = dataset[0] - image_field = None - for field in possible_image_fields: - if field in sample: - value = sample[field] - if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')): - image_field = field - break - if image_field is None: - raise RuntimeError( - f"Could not auto-detect image field. 
Available fields: {list(sample.keys())}. " - f"Please specify IMAGE_FIELD_NAME manually." - ) - print(f"Auto-detected image field: '{image_field}'") - else: - image_field = IMAGE_FIELD_NAME - if image_field not in dataset[0]: - raise RuntimeError( - f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}" - ) - filepaths: list[str] = [] images: list[Image.Image] = [] - for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)): + for i in tqdm(range(N), desc="Loading dataset", total=N): p = dataset[i] - # Try to compose a descriptive identifier - # Handle different dataset structures - identifier_parts = [] - - # Helper function to safely get field value - def safe_get(field_name, default=None): - if field_name in p and p[field_name] is not None: - return p[field_name] - return default - - # Try to get various identifier fields - if safe_get("paper_arxiv_id"): - identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}") - if safe_get("paper_title"): - identifier_parts.append(f"title:{p['paper_title']}") - if safe_get("page_number") is not None: - try: - identifier_parts.append(f"page:{int(p['page_number'])}") - except (ValueError, TypeError): - # If conversion fails, use the raw value or skip - if p['page_number']: - identifier_parts.append(f"page:{p['page_number']}") - if safe_get("page_id"): - identifier_parts.append(f"id:{p['page_id']}") - elif safe_get("questionId"): - identifier_parts.append(f"qid:{p['questionId']}") - elif safe_get("docId"): - identifier_parts.append(f"docId:{p['docId']}") - elif safe_get("id"): - identifier_parts.append(f"id:{p['id']}") - - # If no identifier parts found, create one from index - if identifier_parts: - identifier = "|".join(identifier_parts) - else: - # Create identifier from available fields or index - fallback_parts = [] - # Try common fields that might exist - for field in ["ucsf_document_id", "docId", "questionId", "id"]: - if safe_get(field): - fallback_parts.append(f"{field}:{p[field]}") - break - if fallback_parts: - identifier = "|".join(fallback_parts) + f"|idx:{i}" - else: - identifier = f"doc_{i}" - + # Compose a descriptive identifier for printing later + identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}" filepaths.append(identifier) - - # Get image - try detected field first, then fallback to other common fields - img = None - if image_field in p and p[image_field] is not None: - img = p[image_field] - else: - # Fallback: try other common image field names - for fallback_field in ["image", "page_image", "images", "img"]: - if fallback_field in p and p[fallback_field] is not None: - img = p[fallback_field] - break - - if img is None: - raise RuntimeError( - f"No image found for sample {i}. Available fields: {list(p.keys())}. " - f"Expected field: {image_field}" - ) - - # Ensure it's a PIL Image - if not isinstance(img, Image.Image): - if hasattr(img, 'convert'): - img = img.convert('RGB') - else: - img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img) - images.append(img) + images.append(p["page_image"]) # PIL Image else: _maybe_convert_pdf_to_images(PDF, PAGES_DIR) filepaths, images = _load_images_from_dir(PAGES_DIR) @@ -406,19 +94,6 @@ if need_to_build_index: f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist." 
) print(f"Loaded {len(images)} images") - - # Memory check before loading model - try: - import psutil - import torch - process = psutil.Process(os.getpid()) - mem_info = process.memory_info() - print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB") - if torch.cuda.is_available(): - print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") - print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") - except ImportError: - pass else: print("Skipping dataset loading (using existing index)") filepaths = [] # Not needed when using existing index @@ -427,181 +102,46 @@ else: # %% # Step 3: Load model and processor (only if we need to build index or perform search) -print("Step 3: Loading model and processor...") -print(f" Model: {MODEL}") -try: - import sys - print(f" Python version: {sys.version}") - print(f" Python executable: {sys.executable}") - - model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) - print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}") - - # Memory check after loading model - try: - import psutil - import torch - process = psutil.Process(os.getpid()) - mem_info = process.memory_info() - print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB") - if torch.cuda.is_available(): - print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") - print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") - except ImportError: - pass -except Exception as e: - print(f"✗ Error loading model: {type(e).__name__}: {e}") - import traceback - traceback.print_exc() - raise +model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) +print(f"Using model={model_name}, device={device_str}, dtype={dtype}") # %% # %% # Step 4: Build index if needed -if need_to_build_index: - print("Step 4: Building index...") - print(f" Number of images: {len(images)}") - print(f" Number of filepaths: {len(filepaths)}") +if need_to_build_index and retriever is None: + print("Building index...") + doc_vecs = _embed_images(model, processor, images) + retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images) + print(f"✓ Index built and images saved to: {retriever._images_dir_path()}") + # Clear memory + del images, filepaths, doc_vecs - try: - print(" Embedding images...") - doc_vecs = _embed_images(model, processor, images) - print(f" Embedded {len(doc_vecs)} documents") - print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}") - except Exception as e: - print(f"Error embedding images: {type(e).__name__}: {e}") - import traceback - traceback.print_exc() - raise - - if USE_FAST_PLAID: - # Build Fast-Plaid index - print(" Building Fast-Plaid index...") - try: - fast_plaid_index, build_secs = _build_fast_plaid_index( - FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images - ) - from pathlib import Path - print(f"✓ Fast-Plaid index built in {build_secs:.3f}s") - print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}") - print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}") - except Exception as e: - print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}") - import traceback - traceback.print_exc() - raise - finally: - # Clear memory - print(" Clearing memory...") - del images, filepaths, doc_vecs - else: - # Build original LEANN index - try: - retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images) - print(f"✓ Index built and images saved 
to: {retriever._images_dir_path()}") - except Exception as e: - print(f"Error building LEANN index: {type(e).__name__}: {e}") - import traceback - traceback.print_exc() - raise - finally: - # Clear memory - print(" Clearing memory...") - del images, filepaths, doc_vecs - -# Note: Images are now stored separately, retriever/fast_plaid_index will reference them +# Note: Images are now stored in the index, retriever will load them on-demand from disk # %% # Step 5: Embed query and search -_t0 = time.perf_counter() q_vec = _embed_queries(model, processor, [QUERY])[0] -query_embed_secs = time.perf_counter() - _t0 - -print(f"[Search] Method: {SEARCH_METHOD}") -print(f"[Timing] Query embedding: {query_embed_secs:.3f}s") - -# Run the selected search method and time it -if USE_FAST_PLAID: - # Fast-Plaid search - if fast_plaid_index is None: - fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH) - if fast_plaid_index is None: - raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}") - - results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK) - print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s") -else: - # Original LEANN search - query_np = q_vec.float().numpy() - - if SEARCH_METHOD == "ann": - results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K) - search_secs = time.perf_counter() - _t0 - print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})") - elif SEARCH_METHOD == "exact": - results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K) - search_secs = time.perf_counter() - _t0 - print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})") - elif SEARCH_METHOD == "exact-all": - results = retriever.search_exact_all(query_np, topk=TOPK) - search_secs = time.perf_counter() - _t0 - print(f"[Timing] Search (Exact all): {search_secs:.3f}s") - else: - results = [] +results = retriever.search(q_vec.float().numpy(), topk=TOPK) if not results: print("No results found.") else: print(f'Top {len(results)} results for query: "{QUERY}"') - print("\n[DEBUG] Retrieval details:") top_images: list[Image.Image] = [] - image_hashes = {} # Track image hashes to detect duplicates - for rank, (score, doc_id) in enumerate(results, start=1): - # Retrieve image and metadata based on index type - if USE_FAST_PLAID: - # Fast-Plaid: load image and get metadata - image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id) - if image is None: - print(f"Warning: Could not find image for doc_id {doc_id}") - continue + # Retrieve image from index instead of memory + image = retriever.get_image(doc_id) + if image is None: + print(f"Warning: Could not retrieve image for doc_id {doc_id}") + continue - metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id) - path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}" - top_images.append(image) - else: - # Original LEANN: retrieve from retriever - image = retriever.get_image(doc_id) - if image is None: - print(f"Warning: Could not retrieve image for doc_id {doc_id}") - continue - - metadata = retriever.get_metadata(doc_id) - path = metadata.get("filepath", "unknown") if metadata else "unknown" - top_images.append(image) - - # Calculate image hash to detect duplicates - import hashlib - import io - # Convert image to bytes for hashing - img_bytes = io.BytesIO() - image.save(img_bytes, format='PNG') - image_bytes = img_bytes.getvalue() - image_hash = 
hashlib.md5(image_bytes).hexdigest()[:8] - - # Check if this image was already seen - duplicate_info = "" - if image_hash in image_hashes: - duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]" - else: - image_hashes[image_hash] = rank - - # Print detailed information - print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}") - if metadata: - print(f" Metadata: {metadata}") + metadata = retriever.get_metadata(doc_id) + path = metadata.get("filepath", "unknown") if metadata else "unknown" + # For HF dataset, path is a descriptive identifier, not a real file path + print(f"{rank}) MaxSim: {score:.4f}, Page: {path}") + top_images.append(image) if SAVE_TOP_IMAGE: from pathlib import Path as _Path @@ -664,9 +204,6 @@ if results and SIMILARITY_MAP: # Step 7: Optional answer generation if results and ANSWER: qwen = QwenVL(device=device_str) - _t0 = time.perf_counter() response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS) - gen_secs = time.perf_counter() - _t0 - print(f"[Timing] Generation: {gen_secs:.3f}s") print("\nAnswer:") print(response) diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py deleted file mode 100644 index e68a689..0000000 --- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py +++ /dev/null @@ -1,399 +0,0 @@ -#!/usr/bin/env python3 -""" -Modular script to reproduce NDCG results for ViDoRe v1 benchmark. - -This script uses the interface from leann_multi_vector.py to: -1. Download ViDoRe v1 datasets -2. Build indexes (LEANN or Fast-Plaid) -3. Perform retrieval -4. Evaluate using NDCG metrics - -Usage: - # Evaluate all ViDoRe v1 tasks - python vidore_v1_benchmark.py --model colqwen2 --tasks all - - # Evaluate specific task - python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval - - # Use Fast-Plaid index - python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid - - # Rebuild index - python vidore_v1_benchmark.py --model colqwen2 --rebuild-index -""" - -import argparse -import json -import os -from typing import Optional - -from datasets import load_dataset -from leann_multi_vector import ( - ViDoReBenchmarkEvaluator, - _ensure_repo_paths_importable, -) - -_ensure_repo_paths_importable(__file__) - -# ViDoRe v1 task configurations -# Prompts match MTEB task metadata prompts -VIDORE_V1_TASKS = { - "VidoreArxivQARetrieval": { - "dataset_path": "vidore/arxivqa_test_subsampled_beir", - "revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreDocVQARetrieval": { - "dataset_path": "vidore/docvqa_test_subsampled_beir", - "revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreInfoVQARetrieval": { - "dataset_path": "vidore/infovqa_test_subsampled_beir", - "revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreTabfquadRetrieval": { - "dataset_path": "vidore/tabfquad_test_subsampled_beir", - "revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreTatdqaRetrieval": { - "dataset_path": "vidore/tatdqa_test_beir", - "revision": 
"5feb5630fdff4d8d189ffedb2dba56862fdd45c0", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreShiftProjectRetrieval": { - "dataset_path": "vidore/shiftproject_test_beir", - "revision": "84a382e05c4473fed9cff2bbae95fe2379416117", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreSyntheticDocQAAIRetrieval": { - "dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir", - "revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreSyntheticDocQAEnergyRetrieval": { - "dataset_path": "vidore/syntheticDocQA_energy_test_beir", - "revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreSyntheticDocQAGovernmentReportsRetrieval": { - "dataset_path": "vidore/syntheticDocQA_government_reports_test_beir", - "revision": "b4909afa930f81282fd20601e860668073ad02aa", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "VidoreSyntheticDocQAHealthcareIndustryRetrieval": { - "dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir", - "revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164", - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, -} - - -def load_vidore_v1_data( - dataset_path: str, - revision: Optional[str] = None, - split: str = "test", -): - """ - Load ViDoRe v1 dataset. - - Returns: - corpus: dict mapping corpus_id to PIL Image - queries: dict mapping query_id to query text - qrels: dict mapping query_id to dict of {corpus_id: relevance_score} - """ - print(f"Loading dataset: {dataset_path} (split={split})") - - # Load queries - query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) - - queries = {} - for row in query_ds: - query_id = f"query-{split}-{row['query-id']}" - queries[query_id] = row["query"] - - # Load corpus (images) - corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) - - corpus = {} - for row in corpus_ds: - corpus_id = f"corpus-{split}-{row['corpus-id']}" - # Extract image from the dataset row - if "image" in row: - corpus[corpus_id] = row["image"] - elif "page_image" in row: - corpus[corpus_id] = row["page_image"] - else: - raise ValueError( - f"No image field found in corpus. 
Available fields: {list(row.keys())}" - ) - - # Load qrels (relevance judgments) - qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) - - qrels = {} - for row in qrels_ds: - query_id = f"query-{split}-{row['query-id']}" - corpus_id = f"corpus-{split}-{row['corpus-id']}" - if query_id not in qrels: - qrels[query_id] = {} - qrels[query_id][corpus_id] = int(row["score"]) - - print( - f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings" - ) - - # Filter qrels to only include queries that exist - qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries} - - # Filter out queries without any relevant documents (matching MTEB behavior) - # This is important for correct NDCG calculation - qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0} - queries_filtered = { - qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered - } - - print( - f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings" - ) - - return corpus, queries_filtered, qrels_filtered - - -def evaluate_task( - task_name: str, - model_name: str, - index_path: str, - use_fast_plaid: bool = False, - fast_plaid_index_path: Optional[str] = None, - rebuild_index: bool = False, - top_k: int = 1000, - first_stage_k: int = 500, - k_values: Optional[list[int]] = None, - output_dir: Optional[str] = None, -): - """ - Evaluate a single ViDoRe v1 task. - """ - print(f"\n{'=' * 80}") - print(f"Evaluating task: {task_name}") - print(f"{'=' * 80}") - - # Get task config - if task_name not in VIDORE_V1_TASKS: - raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}") - - task_config = VIDORE_V1_TASKS[task_name] - dataset_path = task_config["dataset_path"] - revision = task_config["revision"] - - # Load data - corpus, queries, qrels = load_vidore_v1_data( - dataset_path=dataset_path, - revision=revision, - split="test", - ) - - # Initialize k_values if not provided - if k_values is None: - k_values = [1, 3, 5, 10, 20, 100, 1000] - - # Check if we have any queries - if len(queries) == 0: - print(f"\nWarning: No queries found for task {task_name}. 
Skipping evaluation.") - # Return zero scores - scores = {} - for k in k_values: - scores[f"ndcg_at_{k}"] = 0.0 - scores[f"map_at_{k}"] = 0.0 - scores[f"recall_at_{k}"] = 0.0 - scores[f"precision_at_{k}"] = 0.0 - scores[f"mrr_at_{k}"] = 0.0 - return scores - - # Initialize evaluator - evaluator = ViDoReBenchmarkEvaluator( - model_name=model_name, - use_fast_plaid=use_fast_plaid, - top_k=top_k, - first_stage_k=first_stage_k, - k_values=k_values, - ) - - # Build or load index - index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path - if index_path_full is None: - index_path_full = f"./indexes/{task_name}_{model_name}" - if use_fast_plaid: - index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid" - - index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus( - corpus=corpus, - index_path=index_path_full, - rebuild=rebuild_index, - ) - - # Search queries - task_prompt = task_config.get("prompt") - results = evaluator.search_queries( - queries=queries, - corpus_ids=corpus_ids_ordered, - index_or_retriever=index_or_retriever, - fast_plaid_index_path=fast_plaid_index_path, - task_prompt=task_prompt, - ) - - # Evaluate - scores = evaluator.evaluate_results(results, qrels, k_values=k_values) - - # Print results - print(f"\n{'=' * 80}") - print(f"Results for {task_name}:") - print(f"{'=' * 80}") - for metric, value in scores.items(): - if isinstance(value, (int, float)): - print(f" {metric}: {value:.5f}") - - # Save results - if output_dir: - os.makedirs(output_dir, exist_ok=True) - results_file = os.path.join(output_dir, f"{task_name}_results.json") - scores_file = os.path.join(output_dir, f"{task_name}_scores.json") - - with open(results_file, "w") as f: - json.dump(results, f, indent=2) - print(f"\nSaved results to: {results_file}") - - with open(scores_file, "w") as f: - json.dump(scores, f, indent=2) - print(f"Saved scores to: {scores_file}") - - return scores - - -def main(): - parser = argparse.ArgumentParser( - description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing" - ) - parser.add_argument( - "--model", - type=str, - default="colqwen2", - choices=["colqwen2", "colpali"], - help="Model to use", - ) - parser.add_argument( - "--task", - type=str, - default=None, - help="Specific task to evaluate (or 'all' for all tasks)", - ) - parser.add_argument( - "--tasks", - type=str, - default="all", - help="Tasks to evaluate: 'all' or comma-separated list", - ) - parser.add_argument( - "--index-path", - type=str, - default=None, - help="Path to LEANN index (auto-generated if not provided)", - ) - parser.add_argument( - "--use-fast-plaid", - action="store_true", - help="Use Fast-Plaid instead of LEANN", - ) - parser.add_argument( - "--fast-plaid-index-path", - type=str, - default=None, - help="Path to Fast-Plaid index (auto-generated if not provided)", - ) - parser.add_argument( - "--rebuild-index", - action="store_true", - help="Rebuild index even if it exists", - ) - parser.add_argument( - "--top-k", - type=int, - default=1000, - help="Top-k results to retrieve (MTEB default is max(k_values)=1000)", - ) - parser.add_argument( - "--first-stage-k", - type=int, - default=500, - help="First stage k for LEANN search", - ) - parser.add_argument( - "--k-values", - type=str, - default="1,3,5,10,20,100,1000", - help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')", - ) - parser.add_argument( - "--output-dir", - type=str, - default="./vidore_v1_results", - help="Output directory for results", - ) - - args = parser.parse_args() 
- - # Parse k_values - k_values = [int(k.strip()) for k in args.k_values.split(",")] - - # Determine tasks to evaluate - if args.task: - tasks_to_eval = [args.task] - elif args.tasks.lower() == "all": - tasks_to_eval = list(VIDORE_V1_TASKS.keys()) - else: - tasks_to_eval = [t.strip() for t in args.tasks.split(",")] - - print(f"Tasks to evaluate: {tasks_to_eval}") - - # Evaluate each task - all_scores = {} - for task_name in tasks_to_eval: - try: - scores = evaluate_task( - task_name=task_name, - model_name=args.model, - index_path=args.index_path, - use_fast_plaid=args.use_fast_plaid, - fast_plaid_index_path=args.fast_plaid_index_path, - rebuild_index=args.rebuild_index, - top_k=args.top_k, - first_stage_k=args.first_stage_k, - k_values=k_values, - output_dir=args.output_dir, - ) - all_scores[task_name] = scores - except Exception as e: - print(f"\nError evaluating {task_name}: {e}") - import traceback - - traceback.print_exc() - continue - - # Print summary - if all_scores: - print(f"\n{'=' * 80}") - print("SUMMARY") - print(f"{'=' * 80}") - for task_name, scores in all_scores.items(): - print(f"\n{task_name}:") - # Print main metrics - for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]: - if metric in scores: - print(f" {metric}: {scores[metric]:.5f}") - - -if __name__ == "__main__": - main() diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py deleted file mode 100644 index 8a34e69..0000000 --- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py +++ /dev/null @@ -1,439 +0,0 @@ -#!/usr/bin/env python3 -""" -Modular script to reproduce NDCG results for ViDoRe v2 benchmark. - -This script uses the interface from leann_multi_vector.py to: -1. Download ViDoRe v2 datasets -2. Build indexes (LEANN or Fast-Plaid) -3. Perform retrieval -4. 
Evaluate using NDCG metrics - -Usage: - # Evaluate all ViDoRe v2 tasks - python vidore_v2_benchmark.py --model colqwen2 --tasks all - - # Evaluate specific task - python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval - - # Use Fast-Plaid index - python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid - - # Rebuild index - python vidore_v2_benchmark.py --model colqwen2 --rebuild-index -""" - -import argparse -import json -import os -from typing import Optional - -from datasets import load_dataset -from leann_multi_vector import ( - ViDoReBenchmarkEvaluator, - _ensure_repo_paths_importable, -) - -_ensure_repo_paths_importable(__file__) - -# Language name to dataset language field value mapping -# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn") -LANGUAGE_MAPPING = { - "english": "eng-Latn", - "french": "fra-Latn", - "spanish": "spa-Latn", - "german": "deu-Latn", -} - -# ViDoRe v2 task configurations -# Prompts match MTEB task metadata prompts -VIDORE_V2_TASKS = { - "Vidore2ESGReportsRetrieval": { - "dataset_path": "vidore/esg_reports_v2", - "revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3", - "languages": ["french", "spanish", "english", "german"], - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "Vidore2EconomicsReportsRetrieval": { - "dataset_path": "vidore/economics_reports_v2", - "revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252", - "languages": ["french", "spanish", "english", "german"], - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "Vidore2BioMedicalLecturesRetrieval": { - "dataset_path": "vidore/biomedical_lectures_v2", - "revision": "a29202f0da409034d651614d87cd8938d254e2ea", - "languages": ["french", "spanish", "english", "german"], - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, - "Vidore2ESGReportsHLRetrieval": { - "dataset_path": "vidore/esg_reports_human_labeled_v2", - "revision": "6d467dedb09a75144ede1421747e47cf036857dd", - # Note: This dataset doesn't have language filtering - all queries are English - "languages": None, # No language filtering needed - "prompt": {"query": "Find a screenshot that relevant to the user's question."}, - }, -} - - -def load_vidore_v2_data( - dataset_path: str, - revision: Optional[str] = None, - split: str = "test", - language: Optional[str] = None, -): - """ - Load ViDoRe v2 dataset. - - Returns: - corpus: dict mapping corpus_id to PIL Image - queries: dict mapping query_id to query text - qrels: dict mapping query_id to dict of {corpus_id: relevance_score} - """ - print(f"Loading dataset: {dataset_path} (split={split}, language={language})") - - # Load queries - query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) - - # Check if dataset has language field before filtering - has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names - - if language and has_language_field: - # Map language name to dataset language field value (e.g., "english" -> "eng-Latn") - dataset_language = LANGUAGE_MAPPING.get(language, language) - query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language) - # Check if filtering resulted in empty dataset - if len(query_ds_filtered) == 0: - print( - f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')." 
- ) - # Try with original language value (dataset might use simple names like 'english') - print(f"Trying with original language value '{language}'...") - query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language) - if len(query_ds_filtered) == 0: - # Try to get a sample to see actual language values - try: - sample_ds = load_dataset( - dataset_path, "queries", split=split, revision=revision - ) - if len(sample_ds) > 0 and "language" in sample_ds.column_names: - sample_langs = set(sample_ds["language"]) - print(f"Available language values in dataset: {sample_langs}") - except Exception: - pass - else: - print( - f"Found {len(query_ds_filtered)} queries using original language value '{language}'" - ) - query_ds = query_ds_filtered - - queries = {} - for row in query_ds: - query_id = f"query-{split}-{row['query-id']}" - queries[query_id] = row["query"] - - # Load corpus (images) - corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) - - corpus = {} - for row in corpus_ds: - corpus_id = f"corpus-{split}-{row['corpus-id']}" - # Extract image from the dataset row - if "image" in row: - corpus[corpus_id] = row["image"] - elif "page_image" in row: - corpus[corpus_id] = row["page_image"] - else: - raise ValueError( - f"No image field found in corpus. Available fields: {list(row.keys())}" - ) - - # Load qrels (relevance judgments) - qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) - - qrels = {} - for row in qrels_ds: - query_id = f"query-{split}-{row['query-id']}" - corpus_id = f"corpus-{split}-{row['corpus-id']}" - if query_id not in qrels: - qrels[query_id] = {} - qrels[query_id][corpus_id] = int(row["score"]) - - print( - f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings" - ) - - # Filter qrels to only include queries that exist - qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries} - - # Filter out queries without any relevant documents (matching MTEB behavior) - # This is important for correct NDCG calculation - qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0} - queries_filtered = { - qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered - } - - print( - f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings" - ) - - return corpus, queries_filtered, qrels_filtered - - -def evaluate_task( - task_name: str, - model_name: str, - index_path: str, - use_fast_plaid: bool = False, - fast_plaid_index_path: Optional[str] = None, - language: Optional[str] = None, - rebuild_index: bool = False, - top_k: int = 100, - first_stage_k: int = 500, - k_values: Optional[list[int]] = None, - output_dir: Optional[str] = None, -): - """ - Evaluate a single ViDoRe v2 task. - """ - print(f"\n{'=' * 80}") - print(f"Evaluating task: {task_name}") - print(f"{'=' * 80}") - - # Get task config - if task_name not in VIDORE_V2_TASKS: - raise ValueError(f"Unknown task: {task_name}. 
Available: {list(VIDORE_V2_TASKS.keys())}") - - task_config = VIDORE_V2_TASKS[task_name] - dataset_path = task_config["dataset_path"] - revision = task_config["revision"] - - # Determine language - if language is None: - # Use first language if multiple available - languages = task_config.get("languages") - if languages is None: - # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval) - language = None - elif len(languages) == 1: - language = languages[0] - else: - language = None - - # Initialize k_values if not provided - if k_values is None: - k_values = [1, 3, 5, 10, 100] - - # Load data - corpus, queries, qrels = load_vidore_v2_data( - dataset_path=dataset_path, - revision=revision, - split="test", - language=language, - ) - - # Check if we have any queries - if len(queries) == 0: - print( - f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation." - ) - # Return zero scores - scores = {} - for k in k_values: - scores[f"ndcg_at_{k}"] = 0.0 - scores[f"map_at_{k}"] = 0.0 - scores[f"recall_at_{k}"] = 0.0 - scores[f"precision_at_{k}"] = 0.0 - scores[f"mrr_at_{k}"] = 0.0 - return scores - - # Initialize evaluator - evaluator = ViDoReBenchmarkEvaluator( - model_name=model_name, - use_fast_plaid=use_fast_plaid, - top_k=top_k, - first_stage_k=first_stage_k, - k_values=k_values, - ) - - # Build or load index - index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path - if index_path_full is None: - index_path_full = f"./indexes/{task_name}_{model_name}" - if use_fast_plaid: - index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid" - - index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus( - corpus=corpus, - index_path=index_path_full, - rebuild=rebuild_index, - ) - - # Search queries - task_prompt = task_config.get("prompt") - results = evaluator.search_queries( - queries=queries, - corpus_ids=corpus_ids_ordered, - index_or_retriever=index_or_retriever, - fast_plaid_index_path=fast_plaid_index_path, - task_prompt=task_prompt, - ) - - # Evaluate - scores = evaluator.evaluate_results(results, qrels, k_values=k_values) - - # Print results - print(f"\n{'=' * 80}") - print(f"Results for {task_name}:") - print(f"{'=' * 80}") - for metric, value in scores.items(): - if isinstance(value, (int, float)): - print(f" {metric}: {value:.5f}") - - # Save results - if output_dir: - os.makedirs(output_dir, exist_ok=True) - results_file = os.path.join(output_dir, f"{task_name}_results.json") - scores_file = os.path.join(output_dir, f"{task_name}_scores.json") - - with open(results_file, "w") as f: - json.dump(results, f, indent=2) - print(f"\nSaved results to: {results_file}") - - with open(scores_file, "w") as f: - json.dump(scores, f, indent=2) - print(f"Saved scores to: {scores_file}") - - return scores - - -def main(): - parser = argparse.ArgumentParser( - description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing" - ) - parser.add_argument( - "--model", - type=str, - default="colqwen2", - choices=["colqwen2", "colpali"], - help="Model to use", - ) - parser.add_argument( - "--task", - type=str, - default=None, - help="Specific task to evaluate (or 'all' for all tasks)", - ) - parser.add_argument( - "--tasks", - type=str, - default="all", - help="Tasks to evaluate: 'all' or comma-separated list", - ) - parser.add_argument( - "--index-path", - type=str, - default=None, - help="Path to LEANN index (auto-generated if not provided)", - ) - parser.add_argument( - "--use-fast-plaid", - 
action="store_true", - help="Use Fast-Plaid instead of LEANN", - ) - parser.add_argument( - "--fast-plaid-index-path", - type=str, - default=None, - help="Path to Fast-Plaid index (auto-generated if not provided)", - ) - parser.add_argument( - "--rebuild-index", - action="store_true", - help="Rebuild index even if it exists", - ) - parser.add_argument( - "--language", - type=str, - default=None, - help="Language to evaluate (default: first available)", - ) - parser.add_argument( - "--top-k", - type=int, - default=100, - help="Top-k results to retrieve", - ) - parser.add_argument( - "--first-stage-k", - type=int, - default=500, - help="First stage k for LEANN search", - ) - parser.add_argument( - "--k-values", - type=str, - default="1,3,5,10,100", - help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')", - ) - parser.add_argument( - "--output-dir", - type=str, - default="./vidore_v2_results", - help="Output directory for results", - ) - - args = parser.parse_args() - - # Parse k_values - k_values = [int(k.strip()) for k in args.k_values.split(",")] - - # Determine tasks to evaluate - if args.task: - tasks_to_eval = [args.task] - elif args.tasks.lower() == "all": - tasks_to_eval = list(VIDORE_V2_TASKS.keys()) - else: - tasks_to_eval = [t.strip() for t in args.tasks.split(",")] - - print(f"Tasks to evaluate: {tasks_to_eval}") - - # Evaluate each task - all_scores = {} - for task_name in tasks_to_eval: - try: - scores = evaluate_task( - task_name=task_name, - model_name=args.model, - index_path=args.index_path, - use_fast_plaid=args.use_fast_plaid, - fast_plaid_index_path=args.fast_plaid_index_path, - language=args.language, - rebuild_index=args.rebuild_index, - top_k=args.top_k, - first_stage_k=args.first_stage_k, - k_values=k_values, - output_dir=args.output_dir, - ) - all_scores[task_name] = scores - except Exception as e: - print(f"\nError evaluating {task_name}: {e}") - import traceback - - traceback.print_exc() - continue - - # Print summary - if all_scores: - print(f"\n{'=' * 80}") - print("SUMMARY") - print(f"{'=' * 80}") - for task_name, scores in all_scores.items(): - print(f"\n{task_name}:") - # Print main metrics - for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]: - if metric in scores: - print(f" {metric}: {scores[metric]:.5f}") - - -if __name__ == "__main__": - main()