Reproduce DocVQA results and add a debug file
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py (new executable file, 132 lines added)
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""

import os
import sys
from pathlib import Path

# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))

from PIL import Image
import torch

from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable

# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def create_test_image():
    """Create a simple test image."""
    # Create a simple RGB image (800x600)
    img = Image.new('RGB', (800, 600), color='white')
    return img


def load_test_image_from_file():
    """Try to load an image from the indexes directory if available."""
    # Try to find an existing image in the indexes directory
    indexes_dir = Path(__file__).parent / "indexes"

    # Look for images in common locations
    possible_paths = [
        indexes_dir / "vidore_fastplaid" / "images",
        indexes_dir / "colvision_large.leann.images",
        indexes_dir / "colvision.leann.images",
    ]

    for img_dir in possible_paths:
        if img_dir.exists():
            # Find first image file
            for ext in ['.png', '.jpg', '.jpeg']:
                for img_file in img_dir.glob(f'*{ext}'):
                    print(f"Loading test image from: {img_file}")
                    return Image.open(img_file)

    return None


def main():
    print("=" * 60)
    print("Testing ColQwen2 Forward Pass")
    print("=" * 60)

    # Step 1: Load or create test image
    print("\n[Step 1] Loading test image...")
    test_image = load_test_image_from_file()
    if test_image is None:
        print("No existing image found, creating a simple test image...")
        test_image = create_test_image()
    else:
        print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")

    # Convert to RGB if needed
    if test_image.mode != 'RGB':
        test_image = test_image.convert('RGB')
        print(f"✓ Converted to RGB: {test_image.size}")

    # Step 2: Load model
    print("\n[Step 2] Loading ColQwen2 model...")
    try:
        model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
        print(f"✓ Model loaded: {model_name}")
        print(f"✓ Device: {device_str}, dtype: {dtype}")

        # Print model info
        if hasattr(model, 'device'):
            print(f"✓ Model device: {model.device}")
        if hasattr(model, 'dtype'):
            print(f"✓ Model dtype: {model.dtype}")

    except Exception as e:
        print(f"✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 3: Test forward pass
    print("\n[Step 3] Running forward pass...")
    try:
        # Use the _embed_images function which handles batching and forward pass
        images = [test_image]
        print(f"Processing {len(images)} image(s)...")

        doc_vecs = _embed_images(model, processor, images)

        print("✓ Forward pass completed!")
        print(f"✓ Number of embeddings: {len(doc_vecs)}")

        if len(doc_vecs) > 0:
            emb = doc_vecs[0]
            print(f"✓ Embedding shape: {emb.shape}")
            print(f"✓ Embedding dtype: {emb.dtype}")
            print("✓ Embedding stats:")
            print(f" - Min: {emb.min().item():.4f}")
            print(f" - Max: {emb.max().item():.4f}")
            print(f" - Mean: {emb.mean().item():.4f}")
            print(f" - Std: {emb.std().item():.4f}")

            # Check for NaN or Inf
            if torch.isnan(emb).any():
                print("⚠ Warning: Embedding contains NaN values!")
            if torch.isinf(emb).any():
                print("⚠ Warning: Embedding contains Inf values!")

    except Exception as e:
        print(f"✗ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "=" * 60)
    print("Test completed successfully!")
    print("=" * 60)


if __name__ == "__main__":
    main()

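The script takes no arguments; since the file is executable and ends in a plain main(), it can be run directly (./colqwen_forward.py or python colqwen_forward.py) from the app directory, assuming the same environment that leann_multi_vector.py already requires. It prints the embedding shape, dtype, and min/max/mean/std statistics for a single page image, which is what makes it useful as a quick forward-pass sanity check.
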
@@ -227,28 +227,26 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
    # 2. Manually adds: query_prefix + text + query_augmentation_token * 10
    # 3. Calls processor.process_queries(batch) where batch is now a list of strings
    # 4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10)
    #
    # This results in duplicate addition: query_prefix is added twice, query_augmentation_token * 20 total
    # We need to match this exactly to reproduce MTEB results

    all_embeds = []
    batch_size = 32  # Match MTEB's default batch_size

    with torch.no_grad():
        for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
-            batch_queries = queries[i:i + batch_size]
+            batch_queries = queries[i : i + batch_size]

            # Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
            # Then process_queries will add them again (resulting in 20 augmentation tokens total)
            batch = [
-                processor.query_prefix
-                + t
-                + processor.query_augmentation_token * 10
+                processor.query_prefix + t + processor.query_augmentation_token * 10
                for t in batch_queries
            ]
            inputs = processor.process_queries(batch)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            if model.device.type == "cuda":
                with torch.autocast(
                    device_type="cuda",
@@ -257,10 +255,10 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
                    outs = model(**inputs)
            else:
                outs = model(**inputs)

            # Match MTEB: convert to float32 on CPU
            all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))

    return all_embeds

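To make the doubled augmentation described in the comments concrete, here is a small illustration (the literal prefix and token strings below are placeholders, not the processor's real values):

# Illustration only: placeholder strings standing in for processor.query_prefix
# and processor.query_augmentation_token.
query_prefix = "Query: "
aug_token = "<aug>"
text = "total revenue in 2020"

manual = query_prefix + text + aug_token * 10      # built by the batch comprehension above (step 2)
final = query_prefix + manual + aug_token * 10     # process_queries adds prefix + suffix again (step 4)
# 'final' contains the prefix twice and 20 augmentation tokens in total,
# which is the exact string MTEB's wrapper ends up encoding.
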
@@ -309,74 +307,82 @@ def _build_fast_plaid_index(
) -> tuple[Any, float]:
    """
    Build a Fast-Plaid index from document embeddings.

    Args:
        index_path: Path to save the Fast-Plaid index
        doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
        filepaths: List of filepath identifiers for each document
        images: List of PIL Images corresponding to each document

    Returns:
        Tuple of (FastPlaid index object, build_time_in_seconds)
    """
    import torch
    from fast_plaid import search as fast_plaid_search

    print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
    _t0 = time.perf_counter()

    # Convert doc_vecs to list of tensors
    documents_embeddings = []
    for i, vec in enumerate(doc_vecs):
        if i % 1000 == 0:
            print(f" Converting embedding {i}/{len(doc_vecs)}...")
        if not isinstance(vec, torch.Tensor):
-            vec = torch.tensor(vec) if isinstance(vec, np.ndarray) else torch.from_numpy(np.array(vec))
+            vec = (
+                torch.tensor(vec)
+                if isinstance(vec, np.ndarray)
+                else torch.from_numpy(np.array(vec))
+            )
        # Ensure float32 for Fast-Plaid
        if vec.dtype != torch.float32:
            vec = vec.float()
        documents_embeddings.append(vec)

    print(f" Converted {len(documents_embeddings)} embeddings")
    if len(documents_embeddings) > 0:
        print(f" First embedding shape: {documents_embeddings[0].shape}")
        print(f" First embedding dtype: {documents_embeddings[0].dtype}")

    # Prepare metadata for Fast-Plaid
    print(f" Preparing metadata for {len(filepaths)} documents...")
    metadata_list = []
    for i, filepath in enumerate(filepaths):
-        metadata_list.append({
-            "filepath": filepath,
-            "index": i,
-        })
+        metadata_list.append(
+            {
+                "filepath": filepath,
+                "index": i,
+            }
+        )

    # Create Fast-Plaid index
    print(f" Creating FastPlaid object with index path: {index_path}")
    try:
        fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
-        print(f" FastPlaid object created successfully")
+        print(" FastPlaid object created successfully")
    except Exception as e:
        print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
        import traceback
+
        traceback.print_exc()
        raise

    print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
    try:
        fast_plaid_index.create(
            documents_embeddings=documents_embeddings,
            metadata=metadata_list,
        )
-        print(f" Fast-Plaid index created successfully")
+        print(" Fast-Plaid index created successfully")
    except Exception as e:
        print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
        import traceback
+
        traceback.print_exc()
        raise

    build_secs = time.perf_counter() - _t0

    # Save images separately (Fast-Plaid doesn't store images)
    print(f" Saving {len(images)} images...")
    images_dir = Path(index_path) / "images"
@@ -384,7 +390,7 @@ def _build_fast_plaid_index(
    for i, img in enumerate(tqdm(images, desc="Saving images")):
        img_path = images_dir / f"doc_{i}.png"
        img.save(str(img_path))

    return fast_plaid_index, build_secs

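After a successful build, the index directory contains the Fast-Plaid data files, including the metadata.db that the existence check below relies on, plus an images/ subdirectory with doc_0.png, doc_1.png, ... written in corpus order by the loop above.
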
@@ -392,30 +398,30 @@ def _fast_plaid_index_exists(index_path: str) -> bool:
    """
    Check if Fast-Plaid index exists by checking for key files.
    This avoids creating the FastPlaid object which may trigger memory allocation.

    Args:
        index_path: Path to the Fast-Plaid index

    Returns:
        True if index appears to exist, False otherwise
    """
    index_path_obj = Path(index_path)
    if not index_path_obj.exists() or not index_path_obj.is_dir():
        return False

    # Fast-Plaid creates a SQLite database file for metadata
    # Check for metadata.db as the most reliable indicator
    metadata_db = index_path_obj / "metadata.db"
    if metadata_db.exists() and metadata_db.stat().st_size > 0:
        return True

    # Also check if directory has any files (might be incomplete index)
    try:
        if any(index_path_obj.iterdir()):
            return True
    except Exception:
        pass

    return False

@@ -424,20 +430,20 @@ def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
    Load Fast-Plaid index if it exists.
    First checks if index files exist, then creates the FastPlaid object.
    The actual index data loading happens lazily when search is called.

    Args:
        index_path: Path to the Fast-Plaid index

    Returns:
        FastPlaid index object if exists, None otherwise
    """
    try:
        from fast_plaid import search as fast_plaid_search

        # First check if index files exist without creating the object
        if not _fast_plaid_index_exists(index_path):
            return None

        # Now try to create FastPlaid object
        # This may trigger some memory allocation, but the full index loading is deferred
        fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
@@ -459,81 +465,105 @@ def _search_fast_plaid(
) -> tuple[list[tuple[float, int]], float]:
    """
    Search Fast-Plaid index with a query embedding.

    Args:
        fast_plaid_index: FastPlaid index object
        query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
        top_k: Number of top results to return

    Returns:
        Tuple of (results_list, search_time_in_seconds)
        results_list: List of (score, doc_id) tuples
    """
    import torch

    _t0 = time.perf_counter()

    # Ensure query is a torch tensor
    if not isinstance(query_vec, torch.Tensor):
-        q_vec_tensor = torch.tensor(query_vec) if isinstance(query_vec, np.ndarray) else torch.from_numpy(np.array(query_vec))
+        q_vec_tensor = (
+            torch.tensor(query_vec)
+            if isinstance(query_vec, np.ndarray)
+            else torch.from_numpy(np.array(query_vec))
+        )
    else:
        q_vec_tensor = query_vec

    # Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
    if q_vec_tensor.dim() == 2:
        q_vec_tensor = q_vec_tensor.unsqueeze(0)  # [1, num_tokens, embedding_dim]

    # Perform search
    scores = fast_plaid_index.search(
        queries_embeddings=q_vec_tensor,
        top_k=top_k,
        show_progress=True,
    )

    search_secs = time.perf_counter() - _t0

    # Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
    results = []
    if scores and len(scores) > 0:
        query_results = scores[0]
        # Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
        results = [(float(score), int(doc_id)) for doc_id, score in query_results]

    return results, search_secs


def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
    """
    Retrieve image for a document from Fast-Plaid index.

    Args:
        index_path: Path to the Fast-Plaid index
-        doc_id: Document ID
+        doc_id: Document ID returned by Fast-Plaid search

    Returns:
        PIL Image if found, None otherwise
+
+    Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
+    doc_id may differ from the file naming index.
    """
+    # First get metadata to find the actual index used for file naming
+    metadata = _get_fast_plaid_metadata(index_path, doc_id)
+    if metadata is None:
+        # Fallback: try using doc_id directly
+        file_index = doc_id
+    else:
+        # Use the 'index' field from metadata, which matches the file naming
+        file_index = metadata.get("index", doc_id)
+
    images_dir = Path(index_path) / "images"
-    image_path = images_dir / f"doc_{doc_id}.png"
+    image_path = images_dir / f"doc_{file_index}.png"

    if image_path.exists():
        return Image.open(image_path)

+    # If not found with index, try doc_id as fallback
+    if file_index != doc_id:
+        fallback_path = images_dir / f"doc_{doc_id}.png"
+        if fallback_path.exists():
+            return Image.open(fallback_path)
+
    return None


def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
    """
    Retrieve metadata for a document from Fast-Plaid index.

    Args:
        index_path: Path to the Fast-Plaid index
        doc_id: Document ID

    Returns:
        Dictionary with metadata if found, None otherwise
    """
    try:
        from fast_plaid import filtering

        metadata_list = filtering.get(index=index_path, subset=[doc_id])
        if metadata_list and len(metadata_list) > 0:
            return metadata_list[0]
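Taken together, these helpers form a small search-to-image round trip. A sketch of how they compose (the index path and query text are example values, and loading the model requires the usual weights/device setup):

# Sketch only: composes the helpers defined above with example inputs.
index_path = "./indexes/colvision_fastplaid"
_, model, processor, _, _, _ = _load_colvision("colqwen2")
query_vec = _embed_queries(model, processor, ["what does figure 3 show?"])[0]

fast_plaid_index = _load_fast_plaid_index_if_exists(index_path)
results, search_secs = _search_fast_plaid(fast_plaid_index, query_vec, 3)
for score, doc_id in results:
    meta = _get_fast_plaid_metadata(index_path, doc_id)   # e.g. {"filepath": "...", "index": 7}
    image = _get_fast_plaid_image(index_path, doc_id)     # opens images/doc_{index}.png if present
    print(f"{score:.4f}", meta["filepath"] if meta else doc_id, image.size if image else None)
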
@@ -1053,18 +1083,18 @@ class ViDoReBenchmarkEvaluator:
    A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
    This class encapsulates common functionality for building indexes, searching, and evaluating.
    """

    def __init__(
        self,
        model_name: str,
        use_fast_plaid: bool = False,
        top_k: int = 100,
        first_stage_k: int = 500,
-        k_values: list[int] = None,
+        k_values: Optional[list[int]] = None,
    ):
        """
        Initialize the evaluator.

        Args:
            model_name: Model name ("colqwen2" or "colpali")
            use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
@@ -1077,19 +1107,21 @@ class ViDoReBenchmarkEvaluator:
        self.top_k = top_k
        self.first_stage_k = first_stage_k
        self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]

        # Load model once (can be reused across tasks)
        self._model = None
        self._processor = None
        self._model_name_actual = None

    def _load_model_if_needed(self):
        """Lazy load the model."""
        if self._model is None:
            print(f"\nLoading model: {self.model_name}")
-            self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(self.model_name)
+            self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
+                self.model_name
+            )
            print(f"Model loaded: {self._model_name_actual}")

    def build_index_from_corpus(
        self,
        corpus: dict[str, Image.Image],
|
    ) -> tuple[Any, list[str]]:
        """
        Build index from corpus images.

        Args:
            corpus: dict mapping corpus_id to PIL Image
            index_path: Path to save/load the index
            rebuild: Whether to rebuild even if index exists

        Returns:
            tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
        """
        self._load_model_if_needed()

        # Ensure consistent ordering
        corpus_ids = sorted(corpus.keys())
        images = [corpus[cid] for cid in corpus_ids]

        if self.use_fast_plaid:
            # Check if Fast-Plaid index exists
            if not rebuild and _load_fast_plaid_index_if_exists(index_path) is not None:
                print(f"Fast-Plaid index already exists at {index_path}")
                return _load_fast_plaid_index_if_exists(index_path), corpus_ids

            print(f"Building Fast-Plaid index at {index_path}...")
            print("Embedding images...")
            doc_vecs = _embed_images(self._model, self._processor, images)

            fast_plaid_index, build_time = _build_fast_plaid_index(
                index_path, doc_vecs, corpus_ids, images
            )
|
            if retriever is not None:
                print(f"LEANN index already exists at {index_path}")
                return retriever, corpus_ids

            print(f"Building LEANN index at {index_path}...")
            print("Embedding images...")
            doc_vecs = _embed_images(self._model, self._processor, images)

            retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
-            print(f"LEANN index built")
+            print("LEANN index built")
            return retriever, corpus_ids

    def search_queries(
        self,
        queries: dict[str, str],
|
    ) -> dict[str, dict[str, float]]:
        """
        Search queries against the index.

        Args:
            queries: dict mapping query_id to query text
            corpus_ids: list of corpus_ids in the same order as the index
            index_or_retriever: index or retriever object
            fast_plaid_index_path: path to Fast-Plaid index (for metadata)
            task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})

        Returns:
            results: dict mapping query_id to dict of {corpus_id: score}
        """
        self._load_model_if_needed()

        print(f"Searching {len(queries)} queries (top_k={self.top_k})...")

        query_ids = list(queries.keys())
        query_texts = [queries[qid] for qid in query_ids]

        # Note: ColPaliEngineWrapper does NOT use task prompt from metadata
        # It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
        # So we don't append task_prompt here to match MTEB behavior

        # Embed queries
        print("Embedding queries...")
        query_vecs = _embed_queries(self._model, self._processor, query_texts)

        results = {}

        for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
            if self.use_fast_plaid:
                # Fast-Plaid search
|
            else:
                # LEANN search
                import torch
+
-                query_np = query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
-                search_results = index_or_retriever.search_exact_all(query_np, topk=self.top_k)
+                query_np = (
+                    query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
+                )
+                search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
            query_results = {}
            for score, doc_id in search_results:
                if doc_id < len(corpus_ids):
                    corpus_id = corpus_ids[doc_id]
                    query_results[corpus_id] = float(score)

            results[query_id] = query_results

        return results

    @staticmethod
    def evaluate_results(
        results: dict[str, dict[str, float]],
        qrels: dict[str, dict[str, int]],
-        k_values: list[int] = None,
+        k_values: Optional[list[int]] = None,
    ) -> dict[str, float]:
        """
        Evaluate retrieval results using NDCG and other metrics.

        Args:
            results: dict mapping query_id to dict of {corpus_id: score}
            qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
            k_values: List of k values for evaluation metrics

        Returns:
            Dictionary of metric scores
        """
        try:
-            import pytrec_eval
            from mteb._evaluators.retrieval_metrics import (
                calculate_retrieval_scores,
                make_score_dict,
            )
        except ImportError:
-            raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")
+            raise ImportError(
+                "pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
+            )

        if k_values is None:
            k_values = [1, 3, 5, 10, 100]

        # Check if we have any queries to evaluate
        if len(results) == 0:
            print("Warning: No queries to evaluate. Returning zero scores.")
|
                scores[f"precision_at_{k}"] = 0.0
                scores[f"mrr_at_{k}"] = 0.0
            return scores

        print(f"Evaluating results with k_values={k_values}...")
        print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")

        # Filter to ensure qrels and results have the same query set
        # This matches MTEB behavior: only evaluate queries that exist in both
        # pytrec_eval only evaluates queries in qrels, so we need to ensure
        # results contains all queries in qrels, and filter out queries not in qrels
        results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
-        qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered}
+        qrels_filtered = {
+            qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
+        }

        print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")

        if len(results_filtered) != len(qrels_filtered):
-            print(f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries")
+            print(
+                f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
+            )
            missing_in_results = set(qrels.keys()) - set(results.keys())
            if missing_in_results:
                print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
                print(f"First 5 missing queries: {list(missing_in_results)[:5]}")

        # Convert qrels to pytrec_eval format
        qrels_pytrec = {}
        for qid, rel_docs in qrels_filtered.items():
-            qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()}
+            qrels_pytrec[qid] = dict(rel_docs.items())

        # Evaluate
        eval_result = calculate_retrieval_scores(
            results=results_filtered,
            qrels=qrels_pytrec,
            k_values=k_values,
        )

        # Format scores
        scores = make_score_dict(
            ndcg=eval_result.ndcg,
@@ -1290,5 +1330,5 @@ class ViDoReBenchmarkEvaluator:
            cv_recall=eval_result.cv_recall,
            task_scores={},
        )

        return scores

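For reference, evaluate_results consumes plain nested dicts; a toy example of the expected shapes (ids and scores invented for illustration):

# Toy inputs for ViDoReBenchmarkEvaluator.evaluate_results; shapes only, values invented.
results = {
    "query-test-0": {"corpus-test-0": 21.7, "corpus-test-3": 18.2},  # query_id -> {corpus_id: score}
}
qrels = {
    "query-test-0": {"corpus-test-0": 1},                            # query_id -> {corpus_id: relevance}
}
scores = ViDoReBenchmarkEvaluator.evaluate_results(results, qrels, k_values=[1, 3, 5])
# scores is a flat metric dict with keys such as precision_at_3 and mrr_at_5
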
@@ -83,7 +83,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
# These are now command-line arguments (see CLI overrides section)
TOPK: int = 3
FIRST_STAGE_K: int = 500
-REBUILD_INDEX: bool = False
+REBUILD_INDEX: bool = True

# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
@@ -122,11 +122,18 @@ parser.add_argument(
    default="./indexes/colvision_fastplaid",
    help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
)
+parser.add_argument(
+    "--topk",
+    type=int,
+    default=TOPK,
+    help=f"Number of top results to retrieve. Default: {TOPK}",
+)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query  # Override QUERY with CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
+TOPK: int = cli_args.topk  # Override TOPK with CLI argument if provided

# %%

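With the new --topk flag the retrieval depth can now be overridden per run (for example passing --topk 10 alongside the existing --query and --use-fast-plaid flags) instead of editing the TOPK constant in the source.
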
@@ -399,7 +406,7 @@ if need_to_build_index:
            f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
        )
    print(f"Loaded {len(images)} images")

# Memory check before loading model
try:
    import psutil
@@ -426,10 +433,10 @@ try:
    import sys
    print(f" Python version: {sys.version}")
    print(f" Python executable: {sys.executable}")

    model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
    print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")

    # Memory check after loading model
    try:
        import psutil
@@ -457,7 +464,7 @@ if need_to_build_index:
    print("Step 4: Building index...")
    print(f" Number of images: {len(images)}")
    print(f" Number of filepaths: {len(filepaths)}")

    try:
        print(" Embedding images...")
        doc_vecs = _embed_images(model, processor, images)
@@ -468,7 +475,7 @@ if need_to_build_index:
        import traceback
        traceback.print_exc()
        raise

    if USE_FAST_PLAID:
        # Build Fast-Plaid index
        print(" Building Fast-Plaid index...")
@@ -523,13 +530,13 @@ if USE_FAST_PLAID:
    fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
    if fast_plaid_index is None:
        raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")

    results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
    print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
else:
    # Original LEANN search
    query_np = q_vec.float().numpy()

    if SEARCH_METHOD == "ann":
        results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
        search_secs = time.perf_counter() - _t0
@@ -548,7 +555,10 @@ if not results:
    print("No results found.")
else:
    print(f'Top {len(results)} results for query: "{QUERY}"')
+    print("\n[DEBUG] Retrieval details:")
    top_images: list[Image.Image] = []
+    image_hashes = {}  # Track image hashes to detect duplicates

    for rank, (score, doc_id) in enumerate(results, start=1):
        # Retrieve image and metadata based on index type
        if USE_FAST_PLAID:
@@ -557,7 +567,7 @@ else:
            if image is None:
                print(f"Warning: Could not find image for doc_id {doc_id}")
                continue

            metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
            path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
            top_images.append(image)
@@ -571,9 +581,27 @@ else:
            metadata = retriever.get_metadata(doc_id)
            path = metadata.get("filepath", "unknown") if metadata else "unknown"
            top_images.append(image)

-        # For HF dataset, path is a descriptive identifier, not a real file path
-        print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
+        # Calculate image hash to detect duplicates
+        import hashlib
+        import io
+        # Convert image to bytes for hashing
+        img_bytes = io.BytesIO()
+        image.save(img_bytes, format='PNG')
+        image_bytes = img_bytes.getvalue()
+        image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
+
+        # Check if this image was already seen
+        duplicate_info = ""
+        if image_hash in image_hashes:
+            duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
+        else:
+            image_hashes[image_hash] = rank
+
+        # Print detailed information
+        print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
+        if metadata:
+            print(f" Metadata: {metadata}")

    if SAVE_TOP_IMAGE:
        from pathlib import Path as _Path

@@ -11,13 +11,13 @@ This script uses the interface from leann_multi_vector.py to:
Usage:
    # Evaluate all ViDoRe v1 tasks
    python vidore_v1_benchmark.py --model colqwen2 --tasks all

    # Evaluate specific task
    python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval

    # Use Fast-Plaid index
    python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid

    # Rebuild index
    python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
"""
@@ -28,11 +28,9 @@ import os
from typing import Optional

from datasets import load_dataset
-from PIL import Image

from leann_multi_vector import (
-    _ensure_repo_paths_importable,
    ViDoReBenchmarkEvaluator,
+    _ensure_repo_paths_importable,
)

_ensure_repo_paths_importable(__file__)
@@ -100,25 +98,25 @@ def load_vidore_v1_data(
):
    """
    Load ViDoRe v1 dataset.

    Returns:
        corpus: dict mapping corpus_id to PIL Image
        queries: dict mapping query_id to query text
        qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
    """
    print(f"Loading dataset: {dataset_path} (split={split})")

    # Load queries
    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)

    queries = {}
    for row in query_ds:
        query_id = f"query-{split}-{row['query-id']}"
        queries[query_id] = row["query"]

    # Load corpus (images)
    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)

    corpus = {}
    for row in corpus_ds:
        corpus_id = f"corpus-{split}-{row['corpus-id']}"
@@ -128,11 +126,13 @@ def load_vidore_v1_data(
        elif "page_image" in row:
            corpus[corpus_id] = row["page_image"]
        else:
-            raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")
+            raise ValueError(
+                f"No image field found in corpus. Available fields: {list(row.keys())}"
+            )

    # Load qrels (relevance judgments)
    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)

    qrels = {}
    for row in qrels_ds:
        query_id = f"query-{split}-{row['query-id']}"
@@ -140,19 +140,25 @@ def load_vidore_v1_data(
        if query_id not in qrels:
            qrels[query_id] = {}
        qrels[query_id][corpus_id] = int(row["score"])

-    print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")
+    print(
+        f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
+    )

    # Filter qrels to only include queries that exist
    qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}

    # Filter out queries without any relevant documents (matching MTEB behavior)
    # This is important for correct NDCG calculation
    qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
-    queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
-    print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
+    queries_filtered = {
+        qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
+    }

+    print(
+        f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
+    )

    return corpus, queries_filtered, qrels_filtered

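For orientation, the three returned mappings have the following shapes (toy values; real ids follow the f"query-{split}-..." and f"corpus-{split}-..." patterns built above):

from PIL import Image

# Toy example of the structures returned by load_vidore_v1_data.
corpus = {"corpus-test-0": Image.new("RGB", (64, 64), "white")}    # corpus_id -> page image
queries = {"query-test-0": "what is the total revenue in 2020?"}   # query_id -> query text
qrels = {"query-test-0": {"corpus-test-0": 1}}                     # query_id -> {corpus_id: relevance}
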
@@ -165,31 +171,35 @@ def evaluate_task(
    rebuild_index: bool = False,
    top_k: int = 1000,
    first_stage_k: int = 500,
-    k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
+    k_values: Optional[list[int]] = None,
    output_dir: Optional[str] = None,
):
    """
    Evaluate a single ViDoRe v1 task.
    """
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
    print(f"Evaluating task: {task_name}")
-    print(f"{'='*80}")
+    print(f"{'=' * 80}")

    # Get task config
    if task_name not in VIDORE_V1_TASKS:
        raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")

    task_config = VIDORE_V1_TASKS[task_name]
    dataset_path = task_config["dataset_path"]
    revision = task_config["revision"]

    # Load data
    corpus, queries, qrels = load_vidore_v1_data(
        dataset_path=dataset_path,
        revision=revision,
        split="test",
    )

+    # Initialize k_values if not provided
+    if k_values is None:
+        k_values = [1, 3, 5, 10, 20, 100, 1000]
+
    # Check if we have any queries
    if len(queries) == 0:
        print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
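Switching the default from a list literal to Optional[...] = None also sidesteps Python's shared-mutable-default pitfall; a toy illustration (function names invented for this example):

# Toy illustration of the mutable-default pitfall the new signature avoids.
def collect(item, bucket=[]):        # a single list is shared across all calls
    bucket.append(item)
    return bucket

collect(1)   # [1]
collect(2)   # [1, 2]  <- state leaks between calls

def collect_safe(item, bucket=None):  # the Optional[...] = None pattern used above
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket
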
@@ -202,7 +212,7 @@ def evaluate_task(
            scores[f"precision_at_{k}"] = 0.0
            scores[f"mrr_at_{k}"] = 0.0
        return scores

    # Initialize evaluator
    evaluator = ViDoReBenchmarkEvaluator(
        model_name=model_name,
@@ -211,20 +221,20 @@ def evaluate_task(
        first_stage_k=first_stage_k,
        k_values=k_values,
    )

    # Build or load index
    index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
    if index_path_full is None:
        index_path_full = f"./indexes/{task_name}_{model_name}"
        if use_fast_plaid:
            index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"

    index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
        corpus=corpus,
        index_path=index_path_full,
        rebuild=rebuild_index,
    )

    # Search queries
    task_prompt = task_config.get("prompt")
    results = evaluator.search_queries(
@@ -234,32 +244,32 @@ def evaluate_task(
|
|||||||
fast_plaid_index_path=fast_plaid_index_path,
|
fast_plaid_index_path=fast_plaid_index_path,
|
||||||
task_prompt=task_prompt,
|
task_prompt=task_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'=' * 80}")
|
||||||
print(f"Results for {task_name}:")
|
print(f"Results for {task_name}:")
|
||||||
print(f"{'='*80}")
|
print(f"{'=' * 80}")
|
||||||
for metric, value in scores.items():
|
for metric, value in scores.items():
|
||||||
if isinstance(value, (int, float)):
|
if isinstance(value, (int, float)):
|
||||||
print(f" {metric}: {value:.5f}")
|
print(f" {metric}: {value:.5f}")
|
||||||
|
|
||||||
# Save results
|
# Save results
|
||||||
if output_dir:
|
if output_dir:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||||
|
|
||||||
with open(results_file, "w") as f:
|
with open(results_file, "w") as f:
|
||||||
json.dump(results, f, indent=2)
|
json.dump(results, f, indent=2)
|
||||||
print(f"\nSaved results to: {results_file}")
|
print(f"\nSaved results to: {results_file}")
|
||||||
|
|
||||||
with open(scores_file, "w") as f:
|
with open(scores_file, "w") as f:
|
||||||
json.dump(scores, f, indent=2)
|
json.dump(scores, f, indent=2)
|
||||||
print(f"Saved scores to: {scores_file}")
|
print(f"Saved scores to: {scores_file}")
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -332,12 +342,12 @@ def main():
|
|||||||
default="./vidore_v1_results",
|
default="./vidore_v1_results",
|
||||||
help="Output directory for results",
|
help="Output directory for results",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Parse k_values
|
# Parse k_values
|
||||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||||
|
|
||||||
# Determine tasks to evaluate
|
# Determine tasks to evaluate
|
||||||
if args.task:
|
if args.task:
|
||||||
tasks_to_eval = [args.task]
|
tasks_to_eval = [args.task]
|
||||||
@@ -345,9 +355,9 @@ def main():
|
|||||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||||
else:
|
else:
|
||||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||||
|
|
||||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
# Evaluate each task
|
# Evaluate each task
|
||||||
all_scores = {}
|
all_scores = {}
|
||||||
for task_name in tasks_to_eval:
|
for task_name in tasks_to_eval:
|
||||||
@@ -368,14 +378,15 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nError evaluating {task_name}: {e}")
|
print(f"\nError evaluating {task_name}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
if all_scores:
|
if all_scores:
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'=' * 80}")
|
||||||
print("SUMMARY")
|
print("SUMMARY")
|
||||||
print(f"{'='*80}")
|
print(f"{'=' * 80}")
|
||||||
for task_name, scores in all_scores.items():
|
for task_name, scores in all_scores.items():
|
||||||
print(f"\n{task_name}:")
|
print(f"\n{task_name}:")
|
||||||
# Print main metrics
|
# Print main metrics
|
||||||
@@ -386,4 +397,3 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
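A note on the signature change above: replacing the literal-list default (`k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000]`) with `k_values: Optional[list[int]] = None` plus an in-body initialization avoids Python's shared mutable default, where a single list object is reused across every call. A minimal standalone sketch of that pattern (the function below is illustrative only, not part of the benchmark code):

from typing import Optional


def evaluate(k_values: Optional[list[int]] = None) -> list[int]:
    # Re-create the default list on every call instead of sharing one object.
    if k_values is None:
        k_values = [1, 3, 5, 10]
    # Mutating the local list cannot leak into later calls.
    k_values.append(100)
    return k_values


print(evaluate())  # [1, 3, 5, 10, 100]
print(evaluate())  # same output; a mutable default would keep growing across calls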
@@ -11,13 +11,13 @@ This script uses the interface from leann_multi_vector.py to:
 Usage:
     # Evaluate all ViDoRe v2 tasks
     python vidore_v2_benchmark.py --model colqwen2 --tasks all
 
     # Evaluate specific task
     python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
 
     # Use Fast-Plaid index
     python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
 
     # Rebuild index
     python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
 """
@@ -28,11 +28,9 @@ import os
 from typing import Optional
 
 from datasets import load_dataset
-from PIL import Image
-
 from leann_multi_vector import (
-    _ensure_repo_paths_importable,
     ViDoReBenchmarkEvaluator,
+    _ensure_repo_paths_importable,
 )
 
 _ensure_repo_paths_importable(__file__)
@@ -85,51 +83,57 @@ def load_vidore_v2_data(
 ):
     """
     Load ViDoRe v2 dataset.
 
     Returns:
         corpus: dict mapping corpus_id to PIL Image
         queries: dict mapping query_id to query text
         qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
     """
     print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
 
     # Load queries
     query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
 
     # Check if dataset has language field before filtering
     has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
 
     if language and has_language_field:
         # Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
         dataset_language = LANGUAGE_MAPPING.get(language, language)
         query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
         # Check if filtering resulted in empty dataset
         if len(query_ds_filtered) == 0:
-            print(f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}').")
+            print(
+                f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
+            )
             # Try with original language value (dataset might use simple names like 'english')
             print(f"Trying with original language value '{language}'...")
             query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
             if len(query_ds_filtered) == 0:
                 # Try to get a sample to see actual language values
                 try:
-                    sample_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+                    sample_ds = load_dataset(
+                        dataset_path, "queries", split=split, revision=revision
+                    )
                     if len(sample_ds) > 0 and "language" in sample_ds.column_names:
                         sample_langs = set(sample_ds["language"])
                         print(f"Available language values in dataset: {sample_langs}")
                 except Exception:
                     pass
             else:
-                print(f"Found {len(query_ds_filtered)} queries using original language value '{language}'")
+                print(
+                    f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
+                )
         query_ds = query_ds_filtered
 
     queries = {}
     for row in query_ds:
         query_id = f"query-{split}-{row['query-id']}"
         queries[query_id] = row["query"]
 
     # Load corpus (images)
     corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
 
     corpus = {}
     for row in corpus_ds:
         corpus_id = f"corpus-{split}-{row['corpus-id']}"
@@ -139,11 +143,13 @@ def load_vidore_v2_data(
         elif "page_image" in row:
             corpus[corpus_id] = row["page_image"]
         else:
-            raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")
+            raise ValueError(
+                f"No image field found in corpus. Available fields: {list(row.keys())}"
+            )
 
     # Load qrels (relevance judgments)
     qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
 
     qrels = {}
     for row in qrels_ds:
         query_id = f"query-{split}-{row['query-id']}"
@@ -151,19 +157,25 @@ def load_vidore_v2_data(
         if query_id not in qrels:
             qrels[query_id] = {}
         qrels[query_id][corpus_id] = int(row["score"])
 
-    print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")
+    print(
+        f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
+    )
 
     # Filter qrels to only include queries that exist
     qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
 
     # Filter out queries without any relevant documents (matching MTEB behavior)
     # This is important for correct NDCG calculation
     qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
-    queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
-    print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
+    queries_filtered = {
+        qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
+    }
+    print(
+        f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
+    )
 
     return corpus, queries_filtered, qrels_filtered
 
 
@@ -177,24 +189,24 @@ def evaluate_task(
     rebuild_index: bool = False,
     top_k: int = 100,
     first_stage_k: int = 500,
-    k_values: list[int] = [1, 3, 5, 10, 100],
+    k_values: Optional[list[int]] = None,
     output_dir: Optional[str] = None,
 ):
     """
     Evaluate a single ViDoRe v2 task.
     """
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
     print(f"Evaluating task: {task_name}")
-    print(f"{'='*80}")
+    print(f"{'=' * 80}")
 
     # Get task config
     if task_name not in VIDORE_V2_TASKS:
         raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
 
     task_config = VIDORE_V2_TASKS[task_name]
     dataset_path = task_config["dataset_path"]
     revision = task_config["revision"]
 
     # Determine language
     if language is None:
         # Use first language if multiple available
@@ -206,7 +218,11 @@ def evaluate_task(
             language = languages[0]
         else:
             language = None
 
+    # Initialize k_values if not provided
+    if k_values is None:
+        k_values = [1, 3, 5, 10, 100]
+
     # Load data
     corpus, queries, qrels = load_vidore_v2_data(
         dataset_path=dataset_path,
@@ -214,10 +230,12 @@ def evaluate_task(
         split="test",
         language=language,
     )
 
     # Check if we have any queries
     if len(queries) == 0:
-        print(f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation.")
+        print(
+            f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
+        )
         # Return zero scores
         scores = {}
         for k in k_values:
@@ -227,7 +245,7 @@ def evaluate_task(
             scores[f"precision_at_{k}"] = 0.0
             scores[f"mrr_at_{k}"] = 0.0
         return scores
 
     # Initialize evaluator
     evaluator = ViDoReBenchmarkEvaluator(
         model_name=model_name,
@@ -236,20 +254,20 @@ def evaluate_task(
         first_stage_k=first_stage_k,
         k_values=k_values,
     )
 
     # Build or load index
     index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
     if index_path_full is None:
         index_path_full = f"./indexes/{task_name}_{model_name}"
         if use_fast_plaid:
             index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
 
     index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
         corpus=corpus,
         index_path=index_path_full,
         rebuild=rebuild_index,
     )
 
     # Search queries
     task_prompt = task_config.get("prompt")
     results = evaluator.search_queries(
@@ -259,32 +277,32 @@ def evaluate_task(
         fast_plaid_index_path=fast_plaid_index_path,
         task_prompt=task_prompt,
     )
 
     # Evaluate
     scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
 
     # Print results
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
     print(f"Results for {task_name}:")
-    print(f"{'='*80}")
+    print(f"{'=' * 80}")
     for metric, value in scores.items():
         if isinstance(value, (int, float)):
             print(f"  {metric}: {value:.5f}")
 
     # Save results
     if output_dir:
         os.makedirs(output_dir, exist_ok=True)
         results_file = os.path.join(output_dir, f"{task_name}_results.json")
         scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
 
         with open(results_file, "w") as f:
             json.dump(results, f, indent=2)
         print(f"\nSaved results to: {results_file}")
 
         with open(scores_file, "w") as f:
             json.dump(scores, f, indent=2)
         print(f"Saved scores to: {scores_file}")
 
     return scores
 
 
@@ -363,12 +381,12 @@ def main():
         default="./vidore_v2_results",
         help="Output directory for results",
     )
 
     args = parser.parse_args()
 
     # Parse k_values
     k_values = [int(k.strip()) for k in args.k_values.split(",")]
 
     # Determine tasks to evaluate
     if args.task:
         tasks_to_eval = [args.task]
@@ -376,9 +394,9 @@ def main():
         tasks_to_eval = list(VIDORE_V2_TASKS.keys())
     else:
         tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
 
     print(f"Tasks to evaluate: {tasks_to_eval}")
 
     # Evaluate each task
     all_scores = {}
     for task_name in tasks_to_eval:
@@ -400,14 +418,15 @@ def main():
         except Exception as e:
             print(f"\nError evaluating {task_name}: {e}")
             import traceback
+
             traceback.print_exc()
             continue
 
     # Print summary
     if all_scores:
-        print(f"\n{'='*80}")
+        print(f"\n{'=' * 80}")
         print("SUMMARY")
-        print(f"{'='*80}")
+        print(f"{'=' * 80}")
         for task_name, scores in all_scores.items():
             print(f"\n{task_name}:")
             # Print main metrics
@@ -418,4 +437,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
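Both loaders end by dropping queries that have no relevant documents before scoring, mirroring the MTEB behavior noted in the comments: a query with an empty qrels entry has an ideal DCG of zero, so leaving it in would distort the averaged nDCG. A small self-contained sketch of that filtering step, using the same `{query_id: {corpus_id: relevance}}` layout (the sample ids and texts below are made up):

# Toy qrels/queries in the dict layout the loaders build.
qrels = {
    "query-test-0": {"corpus-test-3": 1},
    "query-test-1": {},  # no positives: nDCG would be ill-defined for this query
}
queries = {
    "query-test-0": "total revenue reported in 2021",
    "query-test-1": "query with no judged relevant page",
}

# Keep only queries with at least one relevant document, then prune queries to match.
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
queries_filtered = {qid: text for qid, text in queries.items() if qid in qrels_filtered}

print(len(queries_filtered), len(qrels_filtered))  # 1 1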