Reproduce DocVQA results and add a debug script

This commit is contained in:
yichuan-w
2025-12-03 08:54:55 +00:00
parent 07afe546ea
commit 1c690e4a8a
5 changed files with 450 additions and 222 deletions

View File

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""
import os
import sys
from pathlib import Path

# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))

from PIL import Image
import torch

from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable

# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def create_test_image():
    """Create a simple test image."""
    # Create a simple RGB image (800x600)
    img = Image.new('RGB', (800, 600), color='white')
    return img


def load_test_image_from_file():
    """Try to load an image from the indexes directory if available."""
    # Try to find an existing image in the indexes directory
    indexes_dir = Path(__file__).parent / "indexes"
    # Look for images in common locations
    possible_paths = [
        indexes_dir / "vidore_fastplaid" / "images",
        indexes_dir / "colvision_large.leann.images",
        indexes_dir / "colvision.leann.images",
    ]
    for img_dir in possible_paths:
        if img_dir.exists():
            # Find first image file
            for ext in ['.png', '.jpg', '.jpeg']:
                for img_file in img_dir.glob(f'*{ext}'):
                    print(f"Loading test image from: {img_file}")
                    return Image.open(img_file)
    return None


def main():
    print("=" * 60)
    print("Testing ColQwen2 Forward Pass")
    print("=" * 60)

    # Step 1: Load or create test image
    print("\n[Step 1] Loading test image...")
    test_image = load_test_image_from_file()
    if test_image is None:
        print("No existing image found, creating a simple test image...")
        test_image = create_test_image()
    else:
        print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")

    # Convert to RGB if needed
    if test_image.mode != 'RGB':
        test_image = test_image.convert('RGB')
        print(f"✓ Converted to RGB: {test_image.size}")

    # Step 2: Load model
    print("\n[Step 2] Loading ColQwen2 model...")
    try:
        model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
        print(f"✓ Model loaded: {model_name}")
        print(f"✓ Device: {device_str}, dtype: {dtype}")
        # Print model info
        if hasattr(model, 'device'):
            print(f"✓ Model device: {model.device}")
        if hasattr(model, 'dtype'):
            print(f"✓ Model dtype: {model.dtype}")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 3: Test forward pass
    print("\n[Step 3] Running forward pass...")
    try:
        # Use the _embed_images function which handles batching and forward pass
        images = [test_image]
        print(f"Processing {len(images)} image(s)...")
        doc_vecs = _embed_images(model, processor, images)
        print("✓ Forward pass completed!")
        print(f"✓ Number of embeddings: {len(doc_vecs)}")
        if len(doc_vecs) > 0:
            emb = doc_vecs[0]
            print(f"✓ Embedding shape: {emb.shape}")
            print(f"✓ Embedding dtype: {emb.dtype}")
            print("✓ Embedding stats:")
            print(f"  - Min: {emb.min().item():.4f}")
            print(f"  - Max: {emb.max().item():.4f}")
            print(f"  - Mean: {emb.mean().item():.4f}")
            print(f"  - Std: {emb.std().item():.4f}")
            # Check for NaN or Inf
            if torch.isnan(emb).any():
                print("⚠ Warning: Embedding contains NaN values!")
            if torch.isinf(emb).any():
                print("⚠ Warning: Embedding contains Inf values!")
    except Exception as e:
        print(f"✗ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "=" * 60)
    print("Test completed successfully!")
    print("=" * 60)


if __name__ == "__main__":
    main()
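For orientation, here is a minimal sketch of what _embed_images is assumed to do (it mirrors the _embed_queries loop in the next file's diff); the helper name, batch size, and the process_images call are assumptions, not the repository's actual implementation:

def _embed_images_sketch(model, processor, images, batch_size=8):
    import torch
    all_embeds = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i : i + batch_size]
            inputs = processor.process_images(batch)  # colpali_engine processor call (assumed)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outs = model(**inputs)  # multi-vector output: [batch, num_tokens, dim]
            all_embeds.extend(torch.unbind(outs.cpu().to(torch.float32)))
    return all_embeds

Each returned tensor should then have the [num_tokens, embedding_dim] shape that the test prints as "Embedding shape".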

View File

@@ -227,28 +227,26 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
# 2. Manually adds: query_prefix + text + query_augmentation_token * 10 # 2. Manually adds: query_prefix + text + query_augmentation_token * 10
# 3. Calls processor.process_queries(batch) where batch is now a list of strings # 3. Calls processor.process_queries(batch) where batch is now a list of strings
# 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10) # 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10)
# #
# This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total # This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
# We need to match this exactly to reproduce MTEB results # We need to match this exactly to reproduce MTEB results
all_embeds = [] all_embeds = []
batch_size = 32 # Match MTEB's default batch_size batch_size = 32 # Match MTEB's default batch_size
with torch.no_grad(): with torch.no_grad():
for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"): for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
batch_queries = queries[i:i + batch_size] batch_queries = queries[i : i + batch_size]
# Match MTEB: manually add query_prefix + text + query_augmentation_token * 10 # Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
# Then process_queries will add them again (resulting in 20 augmentation tokens total) # Then process_queries will add them again (resulting in 20 augmentation tokens total)
-            batch = [
-                processor.query_prefix
-                + t
-                + processor.query_augmentation_token * 10
-                for t in batch_queries
-            ]
+            batch = [
+                processor.query_prefix + t + processor.query_augmentation_token * 10
+                for t in batch_queries
+            ]
inputs = processor.process_queries(batch) inputs = processor.process_queries(batch)
inputs = {k: v.to(model.device) for k, v in inputs.items()} inputs = {k: v.to(model.device) for k, v in inputs.items()}
if model.device.type == "cuda": if model.device.type == "cuda":
with torch.autocast( with torch.autocast(
device_type="cuda", device_type="cuda",
@@ -257,10 +255,10 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
outs = model(**inputs) outs = model(**inputs)
else: else:
outs = model(**inputs) outs = model(**inputs)
# Match MTEB: convert to float32 on CPU # Match MTEB: convert to float32 on CPU
all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32)))) all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))
return all_embeds return all_embeds
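To make the duplication described in the comments above concrete, here is a small illustration; the prefix and augmentation token values are invented for the example, the real ones come from the processor:

# Hypothetical token values for illustration only.
query_prefix = "Query: "
aug = "<pad>"  # stand-in for processor.query_augmentation_token
text = "What is the revenue in 2023?"

manual = query_prefix + text + aug * 10   # what this wrapper adds by hand
final = query_prefix + manual + aug * 10  # what process_queries then wraps again
assert final.count(query_prefix) == 2     # prefix appears twice
assert final.count(aug) == 20             # 20 augmentation tokens in total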
@@ -309,74 +307,82 @@ def _build_fast_plaid_index(
) -> tuple[Any, float]: ) -> tuple[Any, float]:
""" """
Build a Fast-Plaid index from document embeddings. Build a Fast-Plaid index from document embeddings.
Args: Args:
index_path: Path to save the Fast-Plaid index index_path: Path to save the Fast-Plaid index
doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim]) doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
filepaths: List of filepath identifiers for each document filepaths: List of filepath identifiers for each document
images: List of PIL Images corresponding to each document images: List of PIL Images corresponding to each document
Returns: Returns:
Tuple of (FastPlaid index object, build_time_in_seconds) Tuple of (FastPlaid index object, build_time_in_seconds)
""" """
import torch import torch
from fast_plaid import search as fast_plaid_search from fast_plaid import search as fast_plaid_search
print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...") print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
_t0 = time.perf_counter() _t0 = time.perf_counter()
# Convert doc_vecs to list of tensors # Convert doc_vecs to list of tensors
documents_embeddings = [] documents_embeddings = []
for i, vec in enumerate(doc_vecs): for i, vec in enumerate(doc_vecs):
if i % 1000 == 0: if i % 1000 == 0:
print(f" Converting embedding {i}/{len(doc_vecs)}...") print(f" Converting embedding {i}/{len(doc_vecs)}...")
if not isinstance(vec, torch.Tensor): if not isinstance(vec, torch.Tensor):
-            vec = torch.tensor(vec) if isinstance(vec, np.ndarray) else torch.from_numpy(np.array(vec))
+            vec = (
+                torch.tensor(vec)
+                if isinstance(vec, np.ndarray)
+                else torch.from_numpy(np.array(vec))
+            )
# Ensure float32 for Fast-Plaid # Ensure float32 for Fast-Plaid
if vec.dtype != torch.float32: if vec.dtype != torch.float32:
vec = vec.float() vec = vec.float()
documents_embeddings.append(vec) documents_embeddings.append(vec)
print(f" Converted {len(documents_embeddings)} embeddings") print(f" Converted {len(documents_embeddings)} embeddings")
if len(documents_embeddings) > 0: if len(documents_embeddings) > 0:
print(f" First embedding shape: {documents_embeddings[0].shape}") print(f" First embedding shape: {documents_embeddings[0].shape}")
print(f" First embedding dtype: {documents_embeddings[0].dtype}") print(f" First embedding dtype: {documents_embeddings[0].dtype}")
# Prepare metadata for Fast-Plaid # Prepare metadata for Fast-Plaid
print(f" Preparing metadata for {len(filepaths)} documents...") print(f" Preparing metadata for {len(filepaths)} documents...")
metadata_list = [] metadata_list = []
for i, filepath in enumerate(filepaths): for i, filepath in enumerate(filepaths):
-        metadata_list.append({
-            "filepath": filepath,
-            "index": i,
-        })
+        metadata_list.append(
+            {
+                "filepath": filepath,
+                "index": i,
+            }
+        )
# Create Fast-Plaid index # Create Fast-Plaid index
print(f" Creating FastPlaid object with index path: {index_path}") print(f" Creating FastPlaid object with index path: {index_path}")
try: try:
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path) fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
print(f" FastPlaid object created successfully") print(" FastPlaid object created successfully")
except Exception as e: except Exception as e:
print(f" Error creating FastPlaid object: {type(e).__name__}: {e}") print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...") print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
try: try:
fast_plaid_index.create( fast_plaid_index.create(
documents_embeddings=documents_embeddings, documents_embeddings=documents_embeddings,
metadata=metadata_list, metadata=metadata_list,
) )
print(f" Fast-Plaid index created successfully") print(" Fast-Plaid index created successfully")
except Exception as e: except Exception as e:
print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}") print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
build_secs = time.perf_counter() - _t0 build_secs = time.perf_counter() - _t0
# Save images separately (Fast-Plaid doesn't store images) # Save images separately (Fast-Plaid doesn't store images)
print(f" Saving {len(images)} images...") print(f" Saving {len(images)} images...")
images_dir = Path(index_path) / "images" images_dir = Path(index_path) / "images"
@@ -384,7 +390,7 @@ def _build_fast_plaid_index(
for i, img in enumerate(tqdm(images, desc="Saving images")): for i, img in enumerate(tqdm(images, desc="Saving images")):
img_path = images_dir / f"doc_{i}.png" img_path = images_dir / f"doc_{i}.png"
img.save(str(img_path)) img.save(str(img_path))
return fast_plaid_index, build_secs return fast_plaid_index, build_secs
@@ -392,30 +398,30 @@ def _fast_plaid_index_exists(index_path: str) -> bool:
""" """
Check if Fast-Plaid index exists by checking for key files. Check if Fast-Plaid index exists by checking for key files.
This avoids creating the FastPlaid object which may trigger memory allocation. This avoids creating the FastPlaid object which may trigger memory allocation.
Args: Args:
index_path: Path to the Fast-Plaid index index_path: Path to the Fast-Plaid index
Returns: Returns:
True if index appears to exist, False otherwise True if index appears to exist, False otherwise
""" """
index_path_obj = Path(index_path) index_path_obj = Path(index_path)
if not index_path_obj.exists() or not index_path_obj.is_dir(): if not index_path_obj.exists() or not index_path_obj.is_dir():
return False return False
# Fast-Plaid creates a SQLite database file for metadata # Fast-Plaid creates a SQLite database file for metadata
# Check for metadata.db as the most reliable indicator # Check for metadata.db as the most reliable indicator
metadata_db = index_path_obj / "metadata.db" metadata_db = index_path_obj / "metadata.db"
if metadata_db.exists() and metadata_db.stat().st_size > 0: if metadata_db.exists() and metadata_db.stat().st_size > 0:
return True return True
# Also check if directory has any files (might be incomplete index) # Also check if directory has any files (might be incomplete index)
try: try:
if any(index_path_obj.iterdir()): if any(index_path_obj.iterdir()):
return True return True
except Exception: except Exception:
pass pass
return False return False
@@ -424,20 +430,20 @@ def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
Load Fast-Plaid index if it exists. Load Fast-Plaid index if it exists.
First checks if index files exist, then creates the FastPlaid object. First checks if index files exist, then creates the FastPlaid object.
The actual index data loading happens lazily when search is called. The actual index data loading happens lazily when search is called.
Args: Args:
index_path: Path to the Fast-Plaid index index_path: Path to the Fast-Plaid index
Returns: Returns:
FastPlaid index object if exists, None otherwise FastPlaid index object if exists, None otherwise
""" """
try: try:
from fast_plaid import search as fast_plaid_search from fast_plaid import search as fast_plaid_search
# First check if index files exist without creating the object # First check if index files exist without creating the object
if not _fast_plaid_index_exists(index_path): if not _fast_plaid_index_exists(index_path):
return None return None
# Now try to create FastPlaid object # Now try to create FastPlaid object
# This may trigger some memory allocation, but the full index loading is deferred # This may trigger some memory allocation, but the full index loading is deferred
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path) fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
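A small usage sketch of the load-if-exists pattern above, using the helpers defined in this file (the index path is illustrative):

index_path = "./indexes/colvision_fastplaid"  # illustrative path
fast_plaid_index = _load_fast_plaid_index_if_exists(index_path)
if fast_plaid_index is None:
    print("No usable Fast-Plaid index found; it would have to be built from embeddings first.")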
@@ -459,81 +465,105 @@ def _search_fast_plaid(
) -> tuple[list[tuple[float, int]], float]: ) -> tuple[list[tuple[float, int]], float]:
""" """
Search Fast-Plaid index with a query embedding. Search Fast-Plaid index with a query embedding.
Args: Args:
fast_plaid_index: FastPlaid index object fast_plaid_index: FastPlaid index object
query_vec: Query embedding tensor with shape [num_tokens, embedding_dim] query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
top_k: Number of top results to return top_k: Number of top results to return
Returns: Returns:
Tuple of (results_list, search_time_in_seconds) Tuple of (results_list, search_time_in_seconds)
results_list: List of (score, doc_id) tuples results_list: List of (score, doc_id) tuples
""" """
import torch import torch
_t0 = time.perf_counter() _t0 = time.perf_counter()
# Ensure query is a torch tensor # Ensure query is a torch tensor
if not isinstance(query_vec, torch.Tensor): if not isinstance(query_vec, torch.Tensor):
-        q_vec_tensor = torch.tensor(query_vec) if isinstance(query_vec, np.ndarray) else torch.from_numpy(np.array(query_vec))
+        q_vec_tensor = (
+            torch.tensor(query_vec)
+            if isinstance(query_vec, np.ndarray)
+            else torch.from_numpy(np.array(query_vec))
+        )
else: else:
q_vec_tensor = query_vec q_vec_tensor = query_vec
# Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim] # Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
if q_vec_tensor.dim() == 2: if q_vec_tensor.dim() == 2:
q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim] q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim]
# Perform search # Perform search
scores = fast_plaid_index.search( scores = fast_plaid_index.search(
queries_embeddings=q_vec_tensor, queries_embeddings=q_vec_tensor,
top_k=top_k, top_k=top_k,
show_progress=True, show_progress=True,
) )
search_secs = time.perf_counter() - _t0 search_secs = time.perf_counter() - _t0
# Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples # Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
results = [] results = []
if scores and len(scores) > 0: if scores and len(scores) > 0:
query_results = scores[0] query_results = scores[0]
# Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format # Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
results = [(float(score), int(doc_id)) for doc_id, score in query_results] results = [(float(score), int(doc_id)) for doc_id, score in query_results]
return results, search_secs return results, search_secs
def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]: def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
""" """
Retrieve image for a document from Fast-Plaid index. Retrieve image for a document from Fast-Plaid index.
Args: Args:
index_path: Path to the Fast-Plaid index index_path: Path to the Fast-Plaid index
-        doc_id: Document ID
+        doc_id: Document ID returned by Fast-Plaid search
    Returns:
        PIL Image if found, None otherwise
+    Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
+        doc_id may differ from the file naming index.
    """
+    # First get metadata to find the actual index used for file naming
+    metadata = _get_fast_plaid_metadata(index_path, doc_id)
+    if metadata is None:
+        # Fallback: try using doc_id directly
+        file_index = doc_id
+    else:
+        # Use the 'index' field from metadata, which matches the file naming
+        file_index = metadata.get("index", doc_id)
    images_dir = Path(index_path) / "images"
-    image_path = images_dir / f"doc_{doc_id}.png"
+    image_path = images_dir / f"doc_{file_index}.png"
    if image_path.exists():
        return Image.open(image_path)
+    # If not found with index, try doc_id as fallback
+    if file_index != doc_id:
+        fallback_path = images_dir / f"doc_{doc_id}.png"
+        if fallback_path.exists():
+            return Image.open(fallback_path)
    return None
def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]: def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
""" """
Retrieve metadata for a document from Fast-Plaid index. Retrieve metadata for a document from Fast-Plaid index.
Args: Args:
index_path: Path to the Fast-Plaid index index_path: Path to the Fast-Plaid index
doc_id: Document ID doc_id: Document ID
Returns: Returns:
Dictionary with metadata if found, None otherwise Dictionary with metadata if found, None otherwise
""" """
try: try:
from fast_plaid import filtering from fast_plaid import filtering
metadata_list = filtering.get(index=index_path, subset=[doc_id]) metadata_list = filtering.get(index=index_path, subset=[doc_id])
if metadata_list and len(metadata_list) > 0: if metadata_list and len(metadata_list) > 0:
return metadata_list[0] return metadata_list[0]
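Putting the Fast-Plaid calls above together, a rough end-to-end sketch (toy sizes and an illustrative path; argument names are the ones used in this file, but treat it as an approximation rather than an API reference):

import torch
from fast_plaid import filtering
from fast_plaid import search as fast_plaid_search

docs = [torch.randn(700, 128) for _ in range(8)]  # one [num_tokens, dim] tensor per page (toy sizes)
meta = [{"filepath": f"page_{i}.png", "index": i} for i in range(8)]

index = fast_plaid_search.FastPlaid(index="./indexes/demo_fastplaid")
index.create(documents_embeddings=docs, metadata=meta)

query = torch.randn(1, 32, 128)  # [num_queries, num_tokens, dim]
scores = index.search(queries_embeddings=query, top_k=3, show_progress=False)
for doc_id, score in scores[0]:  # (doc_id, score) pairs, per the comment above
    print(score, filtering.get(index="./indexes/demo_fastplaid", subset=[int(doc_id)]))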
@@ -1053,18 +1083,18 @@ class ViDoReBenchmarkEvaluator:
A reusable class for evaluating ViDoRe benchmarks (v1 and v2). A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
This class encapsulates common functionality for building indexes, searching, and evaluating. This class encapsulates common functionality for building indexes, searching, and evaluating.
""" """
def __init__( def __init__(
self, self,
model_name: str, model_name: str,
use_fast_plaid: bool = False, use_fast_plaid: bool = False,
top_k: int = 100, top_k: int = 100,
first_stage_k: int = 500, first_stage_k: int = 500,
k_values: list[int] = None, k_values: Optional[list[int]] = None,
): ):
""" """
Initialize the evaluator. Initialize the evaluator.
Args: Args:
model_name: Model name ("colqwen2" or "colpali") model_name: Model name ("colqwen2" or "colpali")
use_fast_plaid: Whether to use Fast-Plaid instead of LEANN use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
@@ -1077,19 +1107,21 @@ class ViDoReBenchmarkEvaluator:
self.top_k = top_k self.top_k = top_k
self.first_stage_k = first_stage_k self.first_stage_k = first_stage_k
self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100] self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
# Load model once (can be reused across tasks) # Load model once (can be reused across tasks)
self._model = None self._model = None
self._processor = None self._processor = None
self._model_name_actual = None self._model_name_actual = None
def _load_model_if_needed(self): def _load_model_if_needed(self):
"""Lazy load the model.""" """Lazy load the model."""
if self._model is None: if self._model is None:
print(f"\nLoading model: {self.model_name}") print(f"\nLoading model: {self.model_name}")
-            self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(self.model_name)
+            self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
+                self.model_name
+            )
print(f"Model loaded: {self._model_name_actual}") print(f"Model loaded: {self._model_name_actual}")
def build_index_from_corpus( def build_index_from_corpus(
self, self,
corpus: dict[str, Image.Image], corpus: dict[str, Image.Image],
@@ -1098,31 +1130,31 @@ class ViDoReBenchmarkEvaluator:
) -> tuple[Any, list[str]]: ) -> tuple[Any, list[str]]:
""" """
Build index from corpus images. Build index from corpus images.
Args: Args:
corpus: dict mapping corpus_id to PIL Image corpus: dict mapping corpus_id to PIL Image
index_path: Path to save/load the index index_path: Path to save/load the index
rebuild: Whether to rebuild even if index exists rebuild: Whether to rebuild even if index exists
Returns: Returns:
tuple: (retriever or fast_plaid_index object, list of corpus_ids in order) tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
""" """
self._load_model_if_needed() self._load_model_if_needed()
# Ensure consistent ordering # Ensure consistent ordering
corpus_ids = sorted(corpus.keys()) corpus_ids = sorted(corpus.keys())
images = [corpus[cid] for cid in corpus_ids] images = [corpus[cid] for cid in corpus_ids]
if self.use_fast_plaid: if self.use_fast_plaid:
# Check if Fast-Plaid index exists # Check if Fast-Plaid index exists
if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None: if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
print(f"Fast-Plaid index already exists at {index_path}") print(f"Fast-Plaid index already exists at {index_path}")
return _load_fast_plaid_index_if_exists(index_path), corpus_ids return _load_fast_plaid_index_if_exists(index_path), corpus_ids
print(f"Building Fast-Plaid index at {index_path}...") print(f"Building Fast-Plaid index at {index_path}...")
print("Embedding images...") print("Embedding images...")
doc_vecs = _embed_images(self._model, self._processor, images) doc_vecs = _embed_images(self._model, self._processor, images)
fast_plaid_index, build_time = _build_fast_plaid_index( fast_plaid_index, build_time = _build_fast_plaid_index(
index_path, doc_vecs, corpus_ids, images index_path, doc_vecs, corpus_ids, images
) )
@@ -1135,15 +1167,15 @@ class ViDoReBenchmarkEvaluator:
if retriever is not None: if retriever is not None:
print(f"LEANN index already exists at {index_path}") print(f"LEANN index already exists at {index_path}")
return retriever, corpus_ids return retriever, corpus_ids
print(f"Building LEANN index at {index_path}...") print(f"Building LEANN index at {index_path}...")
print("Embedding images...") print("Embedding images...")
doc_vecs = _embed_images(self._model, self._processor, images) doc_vecs = _embed_images(self._model, self._processor, images)
retriever = _build_index(index_path, doc_vecs, corpus_ids, images) retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
print(f"LEANN index built") print("LEANN index built")
return retriever, corpus_ids return retriever, corpus_ids
def search_queries( def search_queries(
self, self,
queries: dict[str, str], queries: dict[str, str],
@@ -1154,34 +1186,34 @@ class ViDoReBenchmarkEvaluator:
) -> dict[str, dict[str, float]]: ) -> dict[str, dict[str, float]]:
""" """
Search queries against the index. Search queries against the index.
Args: Args:
queries: dict mapping query_id to query text queries: dict mapping query_id to query text
corpus_ids: list of corpus_ids in the same order as the index corpus_ids: list of corpus_ids in the same order as the index
index_or_retriever: index or retriever object index_or_retriever: index or retriever object
fast_plaid_index_path: path to Fast-Plaid index (for metadata) fast_plaid_index_path: path to Fast-Plaid index (for metadata)
task_prompt: Optional dict with prompt for query (e.g., {"query": "..."}) task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
Returns: Returns:
results: dict mapping query_id to dict of {corpus_id: score} results: dict mapping query_id to dict of {corpus_id: score}
""" """
self._load_model_if_needed() self._load_model_if_needed()
print(f"Searching {len(queries)} queries (top_k={self.top_k})...") print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
query_ids = list(queries.keys()) query_ids = list(queries.keys())
query_texts = [queries[qid] for qid in query_ids] query_texts = [queries[qid] for qid in query_ids]
# Note: ColPaliEngineWrapper does NOT use task prompt from metadata # Note: ColPaliEngineWrapper does NOT use task prompt from metadata
# It uses query_prefix + text + query_augmentation_token (handled in _embed_queries) # It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
# So we don't append task_prompt here to match MTEB behavior # So we don't append task_prompt here to match MTEB behavior
# Embed queries # Embed queries
print("Embedding queries...") print("Embedding queries...")
query_vecs = _embed_queries(self._model, self._processor, query_texts) query_vecs = _embed_queries(self._model, self._processor, query_texts)
results = {} results = {}
for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs): for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
if self.use_fast_plaid: if self.use_fast_plaid:
# Fast-Plaid search # Fast-Plaid search
@@ -1194,47 +1226,51 @@ class ViDoReBenchmarkEvaluator:
else: else:
# LEANN search # LEANN search
import torch import torch
-                query_np = query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
-                search_results = index_or_retriever.search_exact_all(query_np, topk=self.top_k)
+                query_np = (
+                    query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
+                )
+                search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
query_results = {} query_results = {}
for score, doc_id in search_results: for score, doc_id in search_results:
if doc_id < len(corpus_ids): if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id] corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score) query_results[corpus_id] = float(score)
results[query_id] = query_results results[query_id] = query_results
return results return results
@staticmethod @staticmethod
def evaluate_results( def evaluate_results(
results: dict[str, dict[str, float]], results: dict[str, dict[str, float]],
qrels: dict[str, dict[str, int]], qrels: dict[str, dict[str, int]],
k_values: list[int] = None, k_values: Optional[list[int]] = None,
) -> dict[str, float]: ) -> dict[str, float]:
""" """
Evaluate retrieval results using NDCG and other metrics. Evaluate retrieval results using NDCG and other metrics.
Args: Args:
results: dict mapping query_id to dict of {corpus_id: score} results: dict mapping query_id to dict of {corpus_id: score}
qrels: dict mapping query_id to dict of {corpus_id: relevance_score} qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
k_values: List of k values for evaluation metrics k_values: List of k values for evaluation metrics
Returns: Returns:
Dictionary of metric scores Dictionary of metric scores
""" """
try: try:
-            import pytrec_eval
            from mteb._evaluators.retrieval_metrics import (
                calculate_retrieval_scores,
                make_score_dict,
            )
        except ImportError:
-            raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")
+            raise ImportError(
+                "pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
+            )
if k_values is None: if k_values is None:
k_values = [1, 3, 5, 10, 100] k_values = [1, 3, 5, 10, 100]
# Check if we have any queries to evaluate # Check if we have any queries to evaluate
if len(results) == 0: if len(results) == 0:
print("Warning: No queries to evaluate. Returning zero scores.") print("Warning: No queries to evaluate. Returning zero scores.")
@@ -1246,38 +1282,42 @@ class ViDoReBenchmarkEvaluator:
scores[f"precision_at_{k}"] = 0.0 scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0 scores[f"mrr_at_{k}"] = 0.0
return scores return scores
print(f"Evaluating results with k_values={k_values}...") print(f"Evaluating results with k_values={k_values}...")
print(f"Before filtering: {len(results)} results, {len(qrels)} qrels") print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
# Filter to ensure qrels and results have the same query set # Filter to ensure qrels and results have the same query set
# This matches MTEB behavior: only evaluate queries that exist in both # This matches MTEB behavior: only evaluate queries that exist in both
# pytrec_eval only evaluates queries in qrels, so we need to ensure # pytrec_eval only evaluates queries in qrels, so we need to ensure
# results contains all queries in qrels, and filter out queries not in qrels # results contains all queries in qrels, and filter out queries not in qrels
results_filtered = {qid: res for qid, res in results.items() if qid in qrels} results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
-        qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered}
+        qrels_filtered = {
+            qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
+        }
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels") print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
if len(results_filtered) != len(qrels_filtered): if len(results_filtered) != len(qrels_filtered):
print(f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries") print(
f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
)
missing_in_results = set(qrels.keys()) - set(results.keys()) missing_in_results = set(qrels.keys()) - set(results.keys())
if missing_in_results: if missing_in_results:
print(f"Queries in qrels but not in results: {len(missing_in_results)} queries") print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
print(f"First 5 missing queries: {list(missing_in_results)[:5]}") print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
# Convert qrels to pytrec_eval format # Convert qrels to pytrec_eval format
qrels_pytrec = {} qrels_pytrec = {}
for qid, rel_docs in qrels_filtered.items(): for qid, rel_docs in qrels_filtered.items():
qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()} qrels_pytrec[qid] = dict(rel_docs.items())
# Evaluate # Evaluate
eval_result = calculate_retrieval_scores( eval_result = calculate_retrieval_scores(
results=results_filtered, results=results_filtered,
qrels=qrels_pytrec, qrels=qrels_pytrec,
k_values=k_values, k_values=k_values,
) )
# Format scores # Format scores
scores = make_score_dict( scores = make_score_dict(
ndcg=eval_result.ndcg, ndcg=eval_result.ndcg,
@@ -1290,5 +1330,5 @@ class ViDoReBenchmarkEvaluator:
cv_recall=eval_result.cv_recall, cv_recall=eval_result.cv_recall,
task_scores={}, task_scores={},
) )
return scores return scores
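A toy example of the evaluate_results contract described above (relevance values are made up, and the metric key names are assumed to follow the *_at_{k} pattern used elsewhere in this file):

results = {"query-test-1": {"corpus-test-1": 12.3, "corpus-test-2": 9.8}}
qrels = {"query-test-1": {"corpus-test-1": 1}}
scores = ViDoReBenchmarkEvaluator.evaluate_results(results, qrels, k_values=[1, 3])
print(scores.get("ndcg_at_1"), scores.get("recall_at_3"))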

View File

@@ -83,7 +83,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
# These are now command-line arguments (see CLI overrides section) # These are now command-line arguments (see CLI overrides section)
TOPK: int = 3 TOPK: int = 3
FIRST_STAGE_K: int = 500 FIRST_STAGE_K: int = 500
-REBUILD_INDEX: bool = False
+REBUILD_INDEX: bool = True
# Artifacts # Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png" SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -122,11 +122,18 @@ parser.add_argument(
default="./indexes/colvision_fastplaid", default="./indexes/colvision_fastplaid",
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'", help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
) )
parser.add_argument(
"--topk",
type=int,
default=TOPK,
help=f"Number of top results to retrieve. Default: {TOPK}",
)
cli_args, _unknown = parser.parse_known_args() cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided QUERY = cli_args.query # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
# %% # %%
@@ -399,7 +406,7 @@ if need_to_build_index:
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist." f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
) )
print(f"Loaded {len(images)} images") print(f"Loaded {len(images)} images")
# Memory check before loading model # Memory check before loading model
try: try:
import psutil import psutil
@@ -426,10 +433,10 @@ try:
import sys import sys
print(f" Python version: {sys.version}") print(f" Python version: {sys.version}")
print(f" Python executable: {sys.executable}") print(f" Python executable: {sys.executable}")
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL) model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}") print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
# Memory check after loading model # Memory check after loading model
try: try:
import psutil import psutil
@@ -457,7 +464,7 @@ if need_to_build_index:
print("Step 4: Building index...") print("Step 4: Building index...")
print(f" Number of images: {len(images)}") print(f" Number of images: {len(images)}")
print(f" Number of filepaths: {len(filepaths)}") print(f" Number of filepaths: {len(filepaths)}")
try: try:
print(" Embedding images...") print(" Embedding images...")
doc_vecs = _embed_images(model, processor, images) doc_vecs = _embed_images(model, processor, images)
@@ -468,7 +475,7 @@ if need_to_build_index:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
if USE_FAST_PLAID: if USE_FAST_PLAID:
# Build Fast-Plaid index # Build Fast-Plaid index
print(" Building Fast-Plaid index...") print(" Building Fast-Plaid index...")
@@ -523,13 +530,13 @@ if USE_FAST_PLAID:
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH) fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
if fast_plaid_index is None: if fast_plaid_index is None:
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}") raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK) results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s") print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
else: else:
# Original LEANN search # Original LEANN search
query_np = q_vec.float().numpy() query_np = q_vec.float().numpy()
if SEARCH_METHOD == "ann": if SEARCH_METHOD == "ann":
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K) results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0 search_secs = time.perf_counter() - _t0
@@ -548,7 +555,10 @@ if not results:
print("No results found.") print("No results found.")
else: else:
print(f'Top {len(results)} results for query: "{QUERY}"') print(f'Top {len(results)} results for query: "{QUERY}"')
print("\n[DEBUG] Retrieval details:")
top_images: list[Image.Image] = [] top_images: list[Image.Image] = []
image_hashes = {} # Track image hashes to detect duplicates
for rank, (score, doc_id) in enumerate(results, start=1): for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image and metadata based on index type # Retrieve image and metadata based on index type
if USE_FAST_PLAID: if USE_FAST_PLAID:
@@ -557,7 +567,7 @@ else:
if image is None: if image is None:
print(f"Warning: Could not find image for doc_id {doc_id}") print(f"Warning: Could not find image for doc_id {doc_id}")
continue continue
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id) metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}" path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
top_images.append(image) top_images.append(image)
@@ -571,9 +581,27 @@ else:
metadata = retriever.get_metadata(doc_id) metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown" path = metadata.get("filepath", "unknown") if metadata else "unknown"
top_images.append(image) top_images.append(image)
-        # For HF dataset, path is a descriptive identifier, not a real file path
-        print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
+        # Calculate image hash to detect duplicates
+        import hashlib
+        import io
+        # Convert image to bytes for hashing
+        img_bytes = io.BytesIO()
+        image.save(img_bytes, format='PNG')
+        image_bytes = img_bytes.getvalue()
+        image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
+        # Check if this image was already seen
+        duplicate_info = ""
+        if image_hash in image_hashes:
+            duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
+        else:
+            image_hashes[image_hash] = rank
+        # Print detailed information
+        print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
+        if metadata:
+            print(f" Metadata: {metadata}")
if SAVE_TOP_IMAGE: if SAVE_TOP_IMAGE:
from pathlib import Path as _Path from pathlib import Path as _Path

View File

@@ -11,13 +11,13 @@ This script uses the interface from leann_multi_vector.py to:
Usage: Usage:
# Evaluate all ViDoRe v1 tasks # Evaluate all ViDoRe v1 tasks
python vidore_v1_benchmark.py --model colqwen2 --tasks all python vidore_v1_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task # Evaluate specific task
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
# Use Fast-Plaid index # Use Fast-Plaid index
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index # Rebuild index
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
""" """
@@ -28,11 +28,9 @@ import os
from typing import Optional from typing import Optional
from datasets import load_dataset from datasets import load_dataset
-from PIL import Image
from leann_multi_vector import (
-    _ensure_repo_paths_importable,
    ViDoReBenchmarkEvaluator,
+    _ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__) _ensure_repo_paths_importable(__file__)
@@ -100,25 +98,25 @@ def load_vidore_v1_data(
): ):
""" """
Load ViDoRe v1 dataset. Load ViDoRe v1 dataset.
Returns: Returns:
corpus: dict mapping corpus_id to PIL Image corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score} qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
""" """
print(f"Loading dataset: {dataset_path} (split={split})") print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries # Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
queries = {} queries = {}
for row in query_ds: for row in query_ds:
query_id = f"query-{split}-{row['query-id']}" query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"] queries[query_id] = row["query"]
# Load corpus (images) # Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {} corpus = {}
for row in corpus_ds: for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}" corpus_id = f"corpus-{split}-{row['corpus-id']}"
@@ -128,11 +126,13 @@ def load_vidore_v1_data(
elif "page_image" in row: elif "page_image" in row:
corpus[corpus_id] = row["page_image"] corpus[corpus_id] = row["page_image"]
else: else:
raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}") raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments) # Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {} qrels = {}
for row in qrels_ds: for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}" query_id = f"query-{split}-{row['query-id']}"
@@ -140,19 +140,25 @@ def load_vidore_v1_data(
if query_id not in qrels: if query_id not in qrels:
qrels[query_id] = {} qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"]) qrels[query_id][corpus_id] = int(row["score"])
print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings") print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
)
# Filter qrels to only include queries that exist # Filter qrels to only include queries that exist
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries} qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
# Filter out queries without any relevant documents (matching MTEB behavior) # Filter out queries without any relevant documents (matching MTEB behavior)
# This is important for correct NDCG calculation # This is important for correct NDCG calculation
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0} qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
-    queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
-    print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
+    queries_filtered = {
+        qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
+    }
+    print(
+        f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
+    )
return corpus, queries_filtered, qrels_filtered return corpus, queries_filtered, qrels_filtered
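As an illustration of the ID scheme built above (toy row values):

split = "test"
row = {"query-id": 42, "corpus-id": 7, "score": 1}
query_id = f"query-{split}-{row['query-id']}"     # "query-test-42"
corpus_id = f"corpus-{split}-{row['corpus-id']}"  # "corpus-test-7"
qrels = {query_id: {corpus_id: int(row["score"])}}
# corpus maps corpus_id -> PIL.Image, queries maps query_id -> query text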
@@ -165,31 +171,35 @@ def evaluate_task(
rebuild_index: bool = False, rebuild_index: bool = False,
top_k: int = 1000, top_k: int = 1000,
first_stage_k: int = 500, first_stage_k: int = 500,
k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000], k_values: Optional[list[int]] = None,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,
): ):
""" """
Evaluate a single ViDoRe v1 task. Evaluate a single ViDoRe v1 task.
""" """
print(f"\n{'='*80}") print(f"\n{'=' * 80}")
print(f"Evaluating task: {task_name}") print(f"Evaluating task: {task_name}")
print(f"{'='*80}") print(f"{'=' * 80}")
# Get task config # Get task config
if task_name not in VIDORE_V1_TASKS: if task_name not in VIDORE_V1_TASKS:
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}") raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name] task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"] dataset_path = task_config["dataset_path"]
revision = task_config["revision"] revision = task_config["revision"]
# Load data # Load data
corpus, queries, qrels = load_vidore_v1_data( corpus, queries, qrels = load_vidore_v1_data(
dataset_path=dataset_path, dataset_path=dataset_path,
revision=revision, revision=revision,
split="test", split="test",
) )
# Initialize k_values if not provided
if k_values is None:
k_values = [1, 3, 5, 10, 20, 100, 1000]
# Check if we have any queries # Check if we have any queries
if len(queries) == 0: if len(queries) == 0:
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.") print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
@@ -202,7 +212,7 @@ def evaluate_task(
scores[f"precision_at_{k}"] = 0.0 scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0 scores[f"mrr_at_{k}"] = 0.0
return scores return scores
# Initialize evaluator # Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator( evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name, model_name=model_name,
@@ -211,20 +221,20 @@ def evaluate_task(
first_stage_k=first_stage_k, first_stage_k=first_stage_k,
k_values=k_values, k_values=k_values,
) )
# Build or load index # Build or load index
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
if index_path_full is None: if index_path_full is None:
index_path_full = f"./indexes/{task_name}_{model_name}" index_path_full = f"./indexes/{task_name}_{model_name}"
if use_fast_plaid: if use_fast_plaid:
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid" index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus( index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
corpus=corpus, corpus=corpus,
index_path=index_path_full, index_path=index_path_full,
rebuild=rebuild_index, rebuild=rebuild_index,
) )
# Search queries # Search queries
task_prompt = task_config.get("prompt") task_prompt = task_config.get("prompt")
results = evaluator.search_queries( results = evaluator.search_queries(
@@ -234,32 +244,32 @@ def evaluate_task(
fast_plaid_index_path=fast_plaid_index_path, fast_plaid_index_path=fast_plaid_index_path,
task_prompt=task_prompt, task_prompt=task_prompt,
) )
# Evaluate # Evaluate
scores = evaluator.evaluate_results(results, qrels, k_values=k_values) scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
# Print results # Print results
print(f"\n{'='*80}") print(f"\n{'=' * 80}")
print(f"Results for {task_name}:") print(f"Results for {task_name}:")
print(f"{'='*80}") print(f"{'=' * 80}")
for metric, value in scores.items(): for metric, value in scores.items():
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
print(f" {metric}: {value:.5f}") print(f" {metric}: {value:.5f}")
# Save results # Save results
if output_dir: if output_dir:
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
results_file = os.path.join(output_dir, f"{task_name}_results.json") results_file = os.path.join(output_dir, f"{task_name}_results.json")
scores_file = os.path.join(output_dir, f"{task_name}_scores.json") scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
with open(results_file, "w") as f: with open(results_file, "w") as f:
json.dump(results, f, indent=2) json.dump(results, f, indent=2)
print(f"\nSaved results to: {results_file}") print(f"\nSaved results to: {results_file}")
with open(scores_file, "w") as f: with open(scores_file, "w") as f:
json.dump(scores, f, indent=2) json.dump(scores, f, indent=2)
print(f"Saved scores to: {scores_file}") print(f"Saved scores to: {scores_file}")
return scores return scores
@@ -332,12 +342,12 @@ def main():
default="./vidore_v1_results", default="./vidore_v1_results",
help="Output directory for results", help="Output directory for results",
) )
args = parser.parse_args() args = parser.parse_args()
# Parse k_values # Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")] k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate # Determine tasks to evaluate
if args.task: if args.task:
tasks_to_eval = [args.task] tasks_to_eval = [args.task]
@@ -345,9 +355,9 @@ def main():
tasks_to_eval = list(VIDORE_V1_TASKS.keys()) tasks_to_eval = list(VIDORE_V1_TASKS.keys())
else: else:
tasks_to_eval = [t.strip() for t in args.tasks.split(",")] tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
print(f"Tasks to evaluate: {tasks_to_eval}") print(f"Tasks to evaluate: {tasks_to_eval}")
# Evaluate each task # Evaluate each task
all_scores = {} all_scores = {}
for task_name in tasks_to_eval: for task_name in tasks_to_eval:
@@ -368,14 +378,15 @@ def main():
except Exception as e: except Exception as e:
print(f"\nError evaluating {task_name}: {e}") print(f"\nError evaluating {task_name}: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
continue continue
# Print summary # Print summary
if all_scores: if all_scores:
print(f"\n{'='*80}") print(f"\n{'=' * 80}")
print("SUMMARY") print("SUMMARY")
print(f"{'='*80}") print(f"{'=' * 80}")
for task_name, scores in all_scores.items(): for task_name, scores in all_scores.items():
print(f"\n{task_name}:") print(f"\n{task_name}:")
# Print main metrics # Print main metrics
@@ -386,4 +397,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
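After a run, the per-task scores written by the output block above can be inspected directly; the task name and metric key below are illustrative:

import json

with open("./vidore_v1_results/VidoreArxivQARetrieval_scores.json") as f:
    scores = json.load(f)
print(scores.get("ndcg_at_5"))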

View File

@@ -11,13 +11,13 @@ This script uses the interface from leann_multi_vector.py to:
Usage: Usage:
# Evaluate all ViDoRe v2 tasks # Evaluate all ViDoRe v2 tasks
python vidore_v2_benchmark.py --model colqwen2 --tasks all python vidore_v2_benchmark.py --model colqwen2 --tasks all
# Evaluate specific task # Evaluate specific task
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
# Use Fast-Plaid index # Use Fast-Plaid index
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
# Rebuild index # Rebuild index
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
""" """
@@ -28,11 +28,9 @@ import os
from typing import Optional from typing import Optional
from datasets import load_dataset from datasets import load_dataset
-from PIL import Image
from leann_multi_vector import (
-    _ensure_repo_paths_importable,
    ViDoReBenchmarkEvaluator,
+    _ensure_repo_paths_importable,
)
_ensure_repo_paths_importable(__file__) _ensure_repo_paths_importable(__file__)
@@ -85,51 +83,57 @@ def load_vidore_v2_data(
): ):
""" """
Load ViDoRe v2 dataset. Load ViDoRe v2 dataset.
Returns: Returns:
corpus: dict mapping corpus_id to PIL Image corpus: dict mapping corpus_id to PIL Image
queries: dict mapping query_id to query text queries: dict mapping query_id to query text
qrels: dict mapping query_id to dict of {corpus_id: relevance_score} qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
""" """
print(f"Loading dataset: {dataset_path} (split={split}, language={language})") print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries # Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Check if dataset has language field before filtering # Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
if language and has_language_field: if language and has_language_field:
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn") # Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
dataset_language = LANGUAGE_MAPPING.get(language, language) dataset_language = LANGUAGE_MAPPING.get(language, language)
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language) query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
# Check if filtering resulted in empty dataset # Check if filtering resulted in empty dataset
if len(query_ds_filtered) == 0: if len(query_ds_filtered) == 0:
print(f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}').") print(
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
)
# Try with original language value (dataset might use simple names like 'english') # Try with original language value (dataset might use simple names like 'english')
print(f"Trying with original language value '{language}'...") print(f"Trying with original language value '{language}'...")
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language) query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
if len(query_ds_filtered) == 0: if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values # Try to get a sample to see actual language values
try: try:
-                    sample_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+                    sample_ds = load_dataset(
+                        dataset_path, "queries", split=split, revision=revision
+                    )
if len(sample_ds) > 0 and "language" in sample_ds.column_names: if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"]) sample_langs = set(sample_ds["language"])
print(f"Available language values in dataset: {sample_langs}") print(f"Available language values in dataset: {sample_langs}")
except Exception: except Exception:
pass pass
else: else:
print(f"Found {len(query_ds_filtered)} queries using original language value '{language}'") print(
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
)
query_ds = query_ds_filtered query_ds = query_ds_filtered
queries = {} queries = {}
for row in query_ds: for row in query_ds:
query_id = f"query-{split}-{row['query-id']}" query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"] queries[query_id] = row["query"]
# Load corpus (images) # Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
corpus = {} corpus = {}
for row in corpus_ds: for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}" corpus_id = f"corpus-{split}-{row['corpus-id']}"
@@ -139,11 +143,13 @@ def load_vidore_v2_data(
elif "page_image" in row: elif "page_image" in row:
corpus[corpus_id] = row["page_image"] corpus[corpus_id] = row["page_image"]
else: else:
raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}") raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
)
# Load qrels (relevance judgments) # Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
qrels = {} qrels = {}
for row in qrels_ds: for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}" query_id = f"query-{split}-{row['query-id']}"
@@ -151,19 +157,25 @@ def load_vidore_v2_data(
        if query_id not in qrels:
            qrels[query_id] = {}
        qrels[query_id][corpus_id] = int(row["score"])

    print(
        f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
    )

    # Filter qrels to only include queries that exist
    qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}

    # Filter out queries without any relevant documents (matching MTEB behavior)
    # This is important for correct NDCG calculation
    qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
    queries_filtered = {
        qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
    }

    print(
        f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
    )

    return corpus, queries_filtered, qrels_filtered
@@ -177,24 +189,24 @@ def evaluate_task(
    rebuild_index: bool = False,
    top_k: int = 100,
    first_stage_k: int = 500,
    k_values: Optional[list[int]] = None,
    output_dir: Optional[str] = None,
):
    """
    Evaluate a single ViDoRe v2 task.
    """
    print(f"\n{'=' * 80}")
    print(f"Evaluating task: {task_name}")
    print(f"{'=' * 80}")

    # Get task config
    if task_name not in VIDORE_V2_TASKS:
        raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")

    task_config = VIDORE_V2_TASKS[task_name]
    dataset_path = task_config["dataset_path"]
    revision = task_config["revision"]

    # Determine language
    if language is None:
        # Use first language if multiple available
@@ -206,7 +218,11 @@ def evaluate_task(
            language = languages[0]
        else:
            language = None

    # Initialize k_values if not provided
    if k_values is None:
        k_values = [1, 3, 5, 10, 100]

    # Load data
    corpus, queries, qrels = load_vidore_v2_data(
        dataset_path=dataset_path,
@@ -214,10 +230,12 @@ def evaluate_task(
split="test", split="test",
language=language, language=language,
) )
# Check if we have any queries # Check if we have any queries
if len(queries) == 0: if len(queries) == 0:
print(f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation.") print(
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
)
# Return zero scores # Return zero scores
scores = {} scores = {}
for k in k_values: for k in k_values:
@@ -227,7 +245,7 @@ def evaluate_task(
scores[f"precision_at_{k}"] = 0.0 scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0 scores[f"mrr_at_{k}"] = 0.0
return scores return scores
# Initialize evaluator # Initialize evaluator
evaluator = ViDoReBenchmarkEvaluator( evaluator = ViDoReBenchmarkEvaluator(
model_name=model_name, model_name=model_name,
@@ -236,20 +254,20 @@ def evaluate_task(
        first_stage_k=first_stage_k,
        k_values=k_values,
    )

    # Build or load index
    index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
    if index_path_full is None:
        index_path_full = f"./indexes/{task_name}_{model_name}"
        if use_fast_plaid:
            index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"

    index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
        corpus=corpus,
        index_path=index_path_full,
        rebuild=rebuild_index,
    )

    # Search queries
    task_prompt = task_config.get("prompt")
    results = evaluator.search_queries(
@@ -259,32 +277,32 @@ def evaluate_task(
        fast_plaid_index_path=fast_plaid_index_path,
        task_prompt=task_prompt,
    )

    # Evaluate
    scores = evaluator.evaluate_results(results, qrels, k_values=k_values)

    # Print results
    print(f"\n{'=' * 80}")
    print(f"Results for {task_name}:")
    print(f"{'=' * 80}")
    for metric, value in scores.items():
        if isinstance(value, (int, float)):
            print(f" {metric}: {value:.5f}")

    # Save results
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        results_file = os.path.join(output_dir, f"{task_name}_results.json")
        scores_file = os.path.join(output_dir, f"{task_name}_scores.json")

        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nSaved results to: {results_file}")

        with open(scores_file, "w") as f:
            json.dump(scores, f, indent=2)
        print(f"Saved scores to: {scores_file}")

    return scores
@@ -363,12 +381,12 @@ def main():
default="./vidore_v2_results", default="./vidore_v2_results",
help="Output directory for results", help="Output directory for results",
) )
args = parser.parse_args() args = parser.parse_args()
# Parse k_values # Parse k_values
k_values = [int(k.strip()) for k in args.k_values.split(",")] k_values = [int(k.strip()) for k in args.k_values.split(",")]
# Determine tasks to evaluate # Determine tasks to evaluate
if args.task: if args.task:
tasks_to_eval = [args.task] tasks_to_eval = [args.task]
@@ -376,9 +394,9 @@ def main():
        tasks_to_eval = list(VIDORE_V2_TASKS.keys())
    else:
        tasks_to_eval = [t.strip() for t in args.tasks.split(",")]

    print(f"Tasks to evaluate: {tasks_to_eval}")

    # Evaluate each task
    all_scores = {}
    for task_name in tasks_to_eval:
@@ -400,14 +418,15 @@ def main():
        except Exception as e:
            print(f"\nError evaluating {task_name}: {e}")
            import traceback

            traceback.print_exc()
            continue

    # Print summary
    if all_scores:
        print(f"\n{'=' * 80}")
        print("SUMMARY")
        print(f"{'=' * 80}")

        for task_name, scores in all_scores.items():
            print(f"\n{task_name}:")
            # Print main metrics
@@ -418,4 +437,3 @@ def main():
if __name__ == "__main__":
    main()