update vidore

yichuan-w
2025-11-14 07:31:24 +00:00
parent a9c014df9e
commit ae3b8af3df
4 changed files with 1101 additions and 59 deletions

.gitignore
View File

@@ -91,7 +91,8 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
*.npy
*.db
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/

View File

@@ -219,32 +219,47 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
    import torch

    model.eval()

    # Match MTEB's query processing from ColPaliEngineWrapper.get_text_embeddings:
    #   1. MTEB receives batch["text"], which may already include an instruction/prompt.
    #   2. It manually adds: query_prefix + text + query_augmentation_token * 10.
    #   3. It then calls processor.process_queries(batch) on the resulting strings.
    #   4. process_queries adds: query_prefix + text + suffix (suffix = query_augmentation_token * 10).
    # MTEB's approach therefore duplicates the additions (20 augmentation tokens in total).
    # Since the prompt is already added in search_queries, two options were considered:
    #   Option 1: just call process_queries and let it handle all additions (no duplication).
    #   Option 2: add the pieces manually, then call process_texts (also no duplication).
    # Testing shows Option 1 works better, so process_queries is called without manual additions.
    all_embeds = []
    batch_size = 32  # Match MTEB's default batch_size
    with torch.no_grad():
        for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
            batch_queries = queries[i : i + batch_size]
            # process_queries adds query_prefix + text + 10 augmentation tokens
            inputs = processor.process_queries(batch_queries)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            if model.device.type == "cuda":
                with torch.autocast(
                    device_type="cuda",
                    dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
                ):
                    outs = model(**inputs)
            else:
                outs = model(**inputs)
            # Match MTEB: convert to float32 on CPU
            all_embeds.extend(list(torch.unbind(outs.cpu().to(torch.float32))))
    return all_embeds
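
# A minimal late-interaction (MaxSim) scoring sketch, showing how the per-token
# embeddings returned by _embed_queries / _embed_images are compared. The helper
# name maxsim_score is illustrative and is not defined elsewhere in this codebase.
import torch

def maxsim_score(query_tokens: torch.Tensor, doc_tokens: torch.Tensor) -> float:
    # query_tokens: [num_query_tokens, dim]; doc_tokens: [num_doc_tokens, dim]
    sim = query_tokens @ doc_tokens.T  # all-pairs token similarities
    # For each query token keep its best-matching document token, then sum
    return sim.max(dim=1).values.sum().item()
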
def _build_index(
@@ -284,6 +299,247 @@ def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
    return None

def _build_fast_plaid_index(
    index_path: str,
    doc_vecs: list[Any],
    filepaths: list[str],
    images: list[Image.Image],
) -> tuple[Any, float]:
    """
    Build a Fast-Plaid index from document embeddings.

    Args:
        index_path: Path to save the Fast-Plaid index
        doc_vecs: List of document embeddings (each a tensor of shape [num_tokens, embedding_dim])
        filepaths: List of filepath identifiers for each document
        images: List of PIL Images corresponding to each document

    Returns:
        Tuple of (FastPlaid index object, build_time_in_seconds)
    """
    import torch
    from fast_plaid import search as fast_plaid_search

    print(f"  Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
    _t0 = time.perf_counter()

    # Convert doc_vecs to a list of float32 tensors
    documents_embeddings = []
    for i, vec in enumerate(doc_vecs):
        if i % 1000 == 0:
            print(f"  Converting embedding {i}/{len(doc_vecs)}...")
        if not isinstance(vec, torch.Tensor):
            vec = torch.tensor(vec) if isinstance(vec, np.ndarray) else torch.from_numpy(np.array(vec))
        # Fast-Plaid expects float32
        if vec.dtype != torch.float32:
            vec = vec.float()
        documents_embeddings.append(vec)
    print(f"  Converted {len(documents_embeddings)} embeddings")
    if len(documents_embeddings) > 0:
        print(f"  First embedding shape: {documents_embeddings[0].shape}")
        print(f"  First embedding dtype: {documents_embeddings[0].dtype}")

    # Prepare metadata for Fast-Plaid
    print(f"  Preparing metadata for {len(filepaths)} documents...")
    metadata_list = []
    for i, filepath in enumerate(filepaths):
        metadata_list.append({
            "filepath": filepath,
            "index": i,
        })

    # Create the Fast-Plaid index
    print(f"  Creating FastPlaid object with index path: {index_path}")
    try:
        fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
        print("  FastPlaid object created successfully")
    except Exception as e:
        print(f"  Error creating FastPlaid object: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise

    print(f"  Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
    try:
        fast_plaid_index.create(
            documents_embeddings=documents_embeddings,
            metadata=metadata_list,
        )
        print("  Fast-Plaid index created successfully")
    except Exception as e:
        print(f"  Error creating Fast-Plaid index: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise
    build_secs = time.perf_counter() - _t0

    # Save images separately (Fast-Plaid does not store images)
    print(f"  Saving {len(images)} images...")
    images_dir = Path(index_path) / "images"
    images_dir.mkdir(parents=True, exist_ok=True)
    for i, img in enumerate(tqdm(images, desc="Saving images")):
        img_path = images_dir / f"doc_{i}.png"
        img.save(str(img_path))
    return fast_plaid_index, build_secs
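
# A hedged usage sketch for _build_fast_plaid_index (the path and variables are
# illustrative; doc_vecs, filepaths, and images must be aligned by position):
#
#     doc_vecs = _embed_images(model, processor, images)
#     index, build_secs = _build_fast_plaid_index(
#         "./indexes/demo_fastplaid", doc_vecs, filepaths, images
#     )
#     # images are copied to ./indexes/demo_fastplaid/images/doc_<i>.png
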
def _fast_plaid_index_exists(index_path: str) -> bool:
    """
    Check whether a Fast-Plaid index exists by looking for its key files.

    This avoids creating the FastPlaid object, which may trigger memory allocation.

    Args:
        index_path: Path to the Fast-Plaid index

    Returns:
        True if the index appears to exist, False otherwise
    """
    index_path_obj = Path(index_path)
    if not index_path_obj.exists() or not index_path_obj.is_dir():
        return False
    # Fast-Plaid creates a SQLite database file for metadata;
    # metadata.db is the most reliable indicator.
    metadata_db = index_path_obj / "metadata.db"
    if metadata_db.exists() and metadata_db.stat().st_size > 0:
        return True
    # Also treat any non-empty directory as a (possibly incomplete) index
    try:
        if any(index_path_obj.iterdir()):
            return True
    except Exception:
        pass
    return False

def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
    """
    Load a Fast-Plaid index if it exists.

    First checks whether the index files exist, then creates the FastPlaid
    object. The actual index data is loaded lazily when search is called.

    Args:
        index_path: Path to the Fast-Plaid index

    Returns:
        FastPlaid index object if the index exists, None otherwise
    """
    try:
        from fast_plaid import search as fast_plaid_search

        # First check whether index files exist, without creating the object
        if not _fast_plaid_index_exists(index_path):
            return None
        # Creating the FastPlaid object may allocate some memory,
        # but full index loading is deferred until search.
        fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
        return fast_plaid_index
    except ImportError:
        # fast-plaid is not installed
        return None
    except Exception as e:
        # Any other error (including memory errors from the Rust backend):
        # return None so the caller can rebuild the index.
        print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}")
        return None

def _search_fast_plaid(
    fast_plaid_index: Any,
    query_vec: Any,
    top_k: int,
) -> tuple[list[tuple[float, int]], float]:
    """
    Search a Fast-Plaid index with a query embedding.

    Args:
        fast_plaid_index: FastPlaid index object
        query_vec: Query embedding tensor of shape [num_tokens, embedding_dim]
        top_k: Number of top results to return

    Returns:
        Tuple of (results_list, search_time_in_seconds), where results_list
        is a list of (score, doc_id) tuples.
    """
    import torch

    _t0 = time.perf_counter()

    # Ensure the query is a torch tensor
    if not isinstance(query_vec, torch.Tensor):
        q_vec_tensor = (
            torch.tensor(query_vec)
            if isinstance(query_vec, np.ndarray)
            else torch.from_numpy(np.array(query_vec))
        )
    else:
        q_vec_tensor = query_vec

    # Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
    if q_vec_tensor.dim() == 2:
        q_vec_tensor = q_vec_tensor.unsqueeze(0)  # [1, num_tokens, embedding_dim]

    # Perform the search
    scores = fast_plaid_index.search(
        queries_embeddings=q_vec_tensor,
        top_k=top_k,
        show_progress=True,
    )
    search_secs = time.perf_counter() - _t0

    # Convert Fast-Plaid results to LEANN's format: a list of (score, doc_id)
    # tuples. Fast-Plaid returns (doc_id, score) pairs per query.
    results = []
    if scores and len(scores) > 0:
        query_results = scores[0]
        results = [(float(score), int(doc_id)) for doc_id, score in query_results]
    return results, search_secs

def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
    """
    Retrieve the image for a document from a Fast-Plaid index.

    Args:
        index_path: Path to the Fast-Plaid index
        doc_id: Document ID

    Returns:
        PIL Image if found, None otherwise
    """
    images_dir = Path(index_path) / "images"
    image_path = images_dir / f"doc_{doc_id}.png"
    if image_path.exists():
        return Image.open(image_path)
    return None

def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
    """
    Retrieve the metadata for a document from a Fast-Plaid index.

    Args:
        index_path: Path to the Fast-Plaid index
        doc_id: Document ID

    Returns:
        Dictionary with metadata if found, None otherwise
    """
    try:
        from fast_plaid import filtering

        metadata_list = filtering.get(index=index_path, subset=[doc_id])
        if metadata_list and len(metadata_list) > 0:
            return metadata_list[0]
    except Exception:
        pass
    return None
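
# Putting the three helpers together, a minimal retrieval round-trip might look
# like this (sketch; "index" and "q_vec" are assumed to come from
# _load_fast_plaid_index_if_exists and _embed_queries respectively):
#
#     results, secs = _search_fast_plaid(index, q_vec, top_k=3)
#     for score, doc_id in results:  # (score, doc_id) pairs, LEANN-style
#         img = _get_fast_plaid_image("./indexes/demo_fastplaid", doc_id)
#         meta = _get_fast_plaid_metadata("./indexes/demo_fastplaid", doc_id)
#         path = meta.get("filepath", f"doc_{doc_id}") if meta else f"doc_{doc_id}"
#         print(f"{path}: {score:.4f}")
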
def _generate_similarity_map(
    model,
    processor,

View File

@@ -2,13 +2,18 @@
# %%
# uv pip install matplotlib qwen_vl_utils
import argparse
import faulthandler
import os
import time
from typing import Any, Optional

import numpy as np
from PIL import Image
from tqdm import tqdm

# Enable faulthandler to get a stack trace on segfault
faulthandler.enable()

from leann_multi_vector import (  # utility functions/classes
    _ensure_repo_paths_importable,
@@ -20,6 +25,11 @@ from leann_multi_vector import ( # utility functions/classes
    _build_index,
    _load_retriever_if_index_exists,
    _generate_similarity_map,
    _build_fast_plaid_index,
    _load_fast_plaid_index_if_exists,
    _search_fast_plaid,
    _get_fast_plaid_image,
    _get_fast_plaid_metadata,
    QwenVL,
)
@@ -69,6 +79,8 @@ PAGES_DIR: str = "./pages"
# Index + retrieval settings
# Use a different index path for the larger dataset to avoid overwriting the existing index
INDEX_PATH: str = "./indexes/colvision_large.leann"

# Fast-Plaid index settings (alternative to the LEANN index)
# These are now command-line arguments (see the CLI overrides section)

TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
@@ -98,24 +110,64 @@ parser.add_argument(
    default=QUERY,
    help=f"Query string to search for. Default: '{QUERY}'",
)
parser.add_argument(
    "--use-fast-plaid",
    action="store_true",
    default=False,
    help="Use fast-plaid instead of LEANN. Default: False",
)
parser.add_argument(
    "--fast-plaid-index-path",
    type=str,
    default="./indexes/colvision_fastplaid",
    help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
)
cli_args, _unknown = parser.parse_known_args()

SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query  # Override QUERY with the CLI argument if provided
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
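
# Example invocations (the script's filename is not shown in this diff, so
# <this_script>.py is a placeholder):
#   python <this_script>.py --query "your question here"
#   python <this_script>.py --use-fast-plaid \
#       --fast-plaid-index-path ./indexes/colvision_fastplaid
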
# %%
# Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None
fast_plaid_index: Optional[Any] = None
need_to_build_index = REBUILD_INDEX

if USE_FAST_PLAID:
    # Fast-Plaid index handling
    if not REBUILD_INDEX:
        try:
            fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
            if fast_plaid_index is not None:
                print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
                need_to_build_index = False
            else:
                print("Fast-Plaid index not found, will build a new index")
                need_to_build_index = True
        except Exception as e:
            # If loading fails (e.g., memory error, corrupted index), rebuild
            print(f"Warning: Failed to load Fast-Plaid index: {e}")
            print("Will rebuild the index...")
            need_to_build_index = True
            fast_plaid_index = None
    else:
        print("REBUILD_INDEX=True, will rebuild Fast-Plaid index")
        need_to_build_index = True
else:
    # Original LEANN index handling
    if not REBUILD_INDEX:
        retriever = _load_retriever_if_index_exists(INDEX_PATH)
        if retriever is not None:
            print(f"✓ Index loaded from {INDEX_PATH}")
            print(f"✓ Images available at: {retriever._images_dir_path()}")
            need_to_build_index = False
        else:
            print("Index not found, will build a new index")
            need_to_build_index = True
    else:
        print("REBUILD_INDEX=True, will rebuild index")
        need_to_build_index = True
# Step 2: Load data only if we need to build the index
@@ -347,6 +399,19 @@ if need_to_build_index:
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist." f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
) )
print(f"Loaded {len(images)} images") print(f"Loaded {len(images)} images")
# Memory check before loading model
try:
import psutil
import torch
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
if torch.cuda.is_available():
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
except ImportError:
pass
else: else:
print("Skipping dataset loading (using existing index)") print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index filepaths = [] # Not needed when using existing index
@@ -355,23 +420,91 @@ else:
# %%
# Step 3: Load model and processor (only if we need to build the index or perform search)
print("Step 3: Loading model and processor...")
print(f"  Model: {MODEL}")
try:
    import sys

    print(f"  Python version: {sys.version}")
    print(f"  Python executable: {sys.executable}")
    model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
    print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")

    # Memory check after loading the model
    try:
        import psutil
        import torch

        process = psutil.Process(os.getpid())
        mem_info = process.memory_info()
        print(f"  Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
        if torch.cuda.is_available():
            print(f"  GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            print(f"  GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    except ImportError:
        pass
except Exception as e:
    print(f"✗ Error loading model: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()
    raise
# %%

# %%
# Step 4: Build index if needed
if need_to_build_index:
    print("Step 4: Building index...")
    print(f"  Number of images: {len(images)}")
    print(f"  Number of filepaths: {len(filepaths)}")

    try:
        print("  Embedding images...")
        doc_vecs = _embed_images(model, processor, images)
        print(f"  Embedded {len(doc_vecs)} documents")
        print(f"  First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
    except Exception as e:
        print(f"Error embedding images: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise

    if USE_FAST_PLAID:
        # Build the Fast-Plaid index
        print("  Building Fast-Plaid index...")
        try:
            fast_plaid_index, build_secs = _build_fast_plaid_index(
                FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
            )
            from pathlib import Path

            print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
            print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
            print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
        except Exception as e:
            print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Clear memory
            print("  Clearing memory...")
            del images, filepaths, doc_vecs
    else:
        # Build the original LEANN index
        try:
            retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
            print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
        except Exception as e:
            print(f"Error building LEANN index: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            # Clear memory
            print("  Clearing memory...")
            del images, filepaths, doc_vecs

# Note: Images are now stored separately; the retriever/fast_plaid_index will reference them
# %%
@@ -380,44 +513,67 @@ _t0 = time.perf_counter()
q_vec = _embed_queries(model, processor, [QUERY])[0]
query_embed_secs = time.perf_counter() - _t0

print(f"[Search] Method: {SEARCH_METHOD}")
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")

# Run the selected search method and time it
_t0 = time.perf_counter()
if USE_FAST_PLAID:
    # Fast-Plaid search
    if fast_plaid_index is None:
        fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
        if fast_plaid_index is None:
            raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")

    results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
    print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
else:
    # Original LEANN search
    query_np = q_vec.float().numpy()
    if SEARCH_METHOD == "ann":
        results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
        search_secs = time.perf_counter() - _t0
        print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
    elif SEARCH_METHOD == "exact":
        results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
        search_secs = time.perf_counter() - _t0
        print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
    elif SEARCH_METHOD == "exact-all":
        results = retriever.search_exact_all(query_np, topk=TOPK)
        search_secs = time.perf_counter() - _t0
        print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
    else:
        results = []

if not results:
    print("No results found.")
else:
    print(f'Top {len(results)} results for query: "{QUERY}"')
    top_images: list[Image.Image] = []
    for rank, (score, doc_id) in enumerate(results, start=1):
        # Retrieve the image and metadata based on the index type
        if USE_FAST_PLAID:
            # Fast-Plaid: load the image from disk and fetch its metadata
            image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
            if image is None:
                print(f"Warning: Could not find image for doc_id {doc_id}")
                continue
            metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
            path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
            top_images.append(image)
        else:
            # Original LEANN: retrieve from the retriever
            image = retriever.get_image(doc_id)
            if image is None:
                print(f"Warning: Could not retrieve image for doc_id {doc_id}")
                continue
            metadata = retriever.get_metadata(doc_id)
            path = metadata.get("filepath", "unknown") if metadata else "unknown"
            top_images.append(image)

        # For the HF dataset, path is a descriptive identifier, not a real file path
        print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")

    if SAVE_TOP_IMAGE:
        from pathlib import Path as _Path

View File

@@ -0,0 +1,629 @@
#!/usr/bin/env python3
"""
Modular script to reproduce NDCG results for the ViDoRe v2 benchmark.

This script uses the interface from leann_multi_vector.py to:
1. Download the ViDoRe v2 datasets
2. Build indexes (LEANN or Fast-Plaid)
3. Perform retrieval
4. Evaluate using NDCG metrics

Usage:
    # Evaluate all ViDoRe v2 tasks
    python vidore_v2_benchmark.py --model colqwen2 --tasks all

    # Evaluate a specific task
    python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval

    # Use a Fast-Plaid index
    python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid

    # Rebuild the index
    python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
"""
import argparse
import json
import os
import time
from pathlib import Path
from typing import Any, Optional

import numpy as np
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

# Import MTEB helpers for the evaluation metrics
try:
    import pytrec_eval
    from mteb._evaluators.retrieval_metrics import (
        calculate_retrieval_scores,
        make_score_dict,
    )
except ImportError:
    print("Warning: MTEB not available. Install with: pip install mteb")
    pytrec_eval = None

from leann_multi_vector import (
    _ensure_repo_paths_importable,
    _load_colvision,
    _embed_images,
    _embed_queries,
    _build_index,
    _load_retriever_if_index_exists,
    _build_fast_plaid_index,
    _load_fast_plaid_index_if_exists,
    _search_fast_plaid,
    _get_fast_plaid_image,
    _get_fast_plaid_metadata,
)

_ensure_repo_paths_importable(__file__)

# Language name to dataset language field value mapping.
# The datasets use ISO 639-3 + ISO 15924 format (e.g., "eng-Latn").
LANGUAGE_MAPPING = {
    "english": "eng-Latn",
    "french": "fra-Latn",
    "spanish": "spa-Latn",
    "german": "deu-Latn",
}

# ViDoRe v2 task configurations.
# Prompts are kept verbatim from the MTEB task metadata (including the original
# grammar) so that query embeddings match MTEB's.
VIDORE_V2_TASKS = {
    "Vidore2ESGReportsRetrieval": {
        "dataset_path": "vidore/esg_reports_v2",
        "revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
        "languages": ["french", "spanish", "english", "german"],
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "Vidore2EconomicsReportsRetrieval": {
        "dataset_path": "vidore/economics_reports_v2",
        "revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
        "languages": ["french", "spanish", "english", "german"],
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "Vidore2BioMedicalLecturesRetrieval": {
        "dataset_path": "vidore/biomedical_lectures_v2",
        "revision": "a29202f0da409034d651614d87cd8938d254e2ea",
        "languages": ["french", "spanish", "english", "german"],
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
    "Vidore2ESGReportsHLRetrieval": {
        "dataset_path": "vidore/esg_reports_human_labeled_v2",
        "revision": "6d467dedb09a75144ede1421747e47cf036857dd",
        # This dataset has no language filtering; all queries are English.
        "languages": None,
        "prompt": {"query": "Find a screenshot that relevant to the user's question."},
    },
}

def load_vidore_v2_data(
    dataset_path: str,
    revision: Optional[str] = None,
    split: str = "test",
    language: Optional[str] = None,
):
    """
    Load a ViDoRe v2 dataset.

    Returns:
        corpus: dict mapping corpus_id to PIL Image
        queries: dict mapping query_id to query text
        qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
    """
    print(f"Loading dataset: {dataset_path} (split={split}, language={language})")

    # Load queries
    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)

    # Check whether the dataset has a language field before filtering
    has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
    if language and has_language_field:
        # Map the language name to the dataset's field value (e.g., "english" -> "eng-Latn")
        dataset_language = LANGUAGE_MAPPING.get(language, language)
        query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
        # Check whether filtering produced an empty dataset
        if len(query_ds_filtered) == 0:
            print(
                f"Warning: No queries found after filtering by language '{language}' "
                f"(mapped to '{dataset_language}')."
            )
            # Try the original language value (the dataset might use simple names like 'english')
            print(f"Trying with original language value '{language}'...")
            query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
            if len(query_ds_filtered) == 0:
                # Load a sample to report the actual language values
                try:
                    sample_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
                    if len(sample_ds) > 0 and "language" in sample_ds.column_names:
                        sample_langs = set(sample_ds["language"])
                        print(f"Available language values in dataset: {sample_langs}")
                except Exception:
                    pass
            else:
                print(f"Found {len(query_ds_filtered)} queries using original language value '{language}'")
        query_ds = query_ds_filtered

    queries = {}
    for row in query_ds:
        query_id = f"query-{split}-{row['query-id']}"
        queries[query_id] = row["query"]

    # Load the corpus (images)
    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
    corpus = {}
    for row in corpus_ds:
        corpus_id = f"corpus-{split}-{row['corpus-id']}"
        # Extract the image from the dataset row
        if "image" in row:
            corpus[corpus_id] = row["image"]
        elif "page_image" in row:
            corpus[corpus_id] = row["page_image"]
        else:
            raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")

    # Load qrels (relevance judgments)
    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
    qrels = {}
    for row in qrels_ds:
        query_id = f"query-{split}-{row['query-id']}"
        corpus_id = f"corpus-{split}-{row['corpus-id']}"
        if query_id not in qrels:
            qrels[query_id] = {}
        qrels[query_id][corpus_id] = int(row["score"])

    print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")

    # Keep only qrels for queries that actually exist
    qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
    return corpus, queries, qrels
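
# Shapes of the returned structures, with illustrative ids and query text:
#
#     corpus  = {"corpus-test-0": <PIL.Image.Image>, ...}
#     queries = {"query-test-0": "What are the company's scope 2 emissions?", ...}
#     qrels   = {"query-test-0": {"corpus-test-3": 1, "corpus-test-9": 1}, ...}
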
def build_index_from_corpus(
    corpus: dict[str, Image.Image],
    model,
    processor,
    index_path: str,
    use_fast_plaid: bool = False,
    rebuild: bool = False,
):
    """
    Build an index from the corpus images.

    Returns:
        tuple: (retriever or fast_plaid_index object, list of corpus_ids in index order)
    """
    # Sort for a consistent ordering between builds
    corpus_ids = sorted(corpus.keys())
    images = [corpus[cid] for cid in corpus_ids]

    if use_fast_plaid:
        # Reuse an existing Fast-Plaid index if one can be loaded
        if not rebuild:
            existing_index = _load_fast_plaid_index_if_exists(index_path)
            if existing_index is not None:
                print(f"Fast-Plaid index already exists at {index_path}")
                return existing_index, corpus_ids

        print(f"Building Fast-Plaid index at {index_path}...")
        print("Embedding images...")
        doc_vecs = _embed_images(model, processor, images)
        fast_plaid_index, build_time = _build_fast_plaid_index(
            index_path, doc_vecs, corpus_ids, images
        )
        print(f"Fast-Plaid index built in {build_time:.2f}s")
        return fast_plaid_index, corpus_ids
    else:
        # Reuse an existing LEANN index if one can be loaded
        if not rebuild:
            retriever = _load_retriever_if_index_exists(index_path)
            if retriever is not None:
                print(f"LEANN index already exists at {index_path}")
                return retriever, corpus_ids

        print(f"Building LEANN index at {index_path}...")
        print("Embedding images...")
        doc_vecs = _embed_images(model, processor, images)
        retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
        print("LEANN index built")
        return retriever, corpus_ids

def search_queries(
    queries: dict[str, str],
    corpus_ids: list[str],
    model,
    processor,
    index_or_retriever: Any,
    use_fast_plaid: bool = False,
    fast_plaid_index_path: Optional[str] = None,
    top_k: int = 100,
    first_stage_k: int = 500,
    task_prompt: Optional[dict[str, str]] = None,
) -> dict[str, dict[str, float]]:
    """
    Search queries against the index.

    Args:
        queries: dict mapping query_id to query text
        corpus_ids: list of corpus_ids in the same order as the index
        model: model object
        processor: processor object
        index_or_retriever: index or retriever object
        use_fast_plaid: whether to use Fast-Plaid
        fast_plaid_index_path: path to the Fast-Plaid index (for metadata)
        top_k: number of results to retrieve per query
        first_stage_k: first-stage k for LEANN search
        task_prompt: optional dict with a query prompt (e.g., {"query": "..."})

    Returns:
        results: dict mapping query_id to dict of {corpus_id: score}
    """
    print(f"Searching {len(queries)} queries (top_k={top_k})...")
    query_ids = list(queries.keys())
    query_texts = [queries[qid] for qid in query_ids]

    # Match MTEB: combine queries with the instruction/prompt if provided.
    # MTEB's _combine_queries_with_instruction_text does: query + " " + instruction
    if task_prompt and "query" in task_prompt:
        instruction = task_prompt["query"]
        query_texts = [q + " " + instruction for q in query_texts]
        print(f"Added task prompt to queries: {instruction}")

    # Embed queries
    print("Embedding queries...")
    query_vecs = _embed_queries(model, processor, query_texts)

    results = {}
    for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
        if use_fast_plaid:
            # Fast-Plaid search
            search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, top_k)
        else:
            # LEANN search
            query_np = query_vec.float().numpy()
            search_results = index_or_retriever.search_exact_all(query_np, topk=top_k)
        # Convert each doc_id back to its corpus_id
        query_results = {}
        for score, doc_id in search_results:
            if doc_id < len(corpus_ids):
                corpus_id = corpus_ids[doc_id]
                query_results[corpus_id] = float(score)
        results[query_id] = query_results
    return results
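
# Concretely, with the ViDoRe v2 prompt above, each embedded query string is the
# raw query plus the instruction, matching MTEB's
# _combine_queries_with_instruction_text (toy values):
#
#     query = "What are the key ESG risks?"  # illustrative
#     instruction = "Find a screenshot that relevant to the user's question."
#     embedded_text = query + " " + instruction  # what _embed_queries receives
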
def evaluate_results(
    results: dict[str, dict[str, float]],
    qrels: dict[str, dict[str, int]],
    k_values: list[int] = [1, 3, 5, 10, 100],
) -> dict[str, float]:
    """
    Evaluate retrieval results using NDCG and other metrics.

    Returns:
        Dictionary of metric scores
    """
    if pytrec_eval is None:
        raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")

    # Handle the degenerate case of no queries
    if len(results) == 0:
        print("Warning: No queries to evaluate. Returning zero scores.")
        scores = {}
        for k in k_values:
            scores[f"ndcg_at_{k}"] = 0.0
            scores[f"map_at_{k}"] = 0.0
            scores[f"recall_at_{k}"] = 0.0
            scores[f"precision_at_{k}"] = 0.0
            scores[f"mrr_at_{k}"] = 0.0
        return scores

    print(f"Evaluating results with k_values={k_values}...")

    # Convert qrels to pytrec_eval format
    qrels_pytrec = {qid: dict(rel_docs) for qid, rel_docs in qrels.items()}

    # Evaluate
    eval_result = calculate_retrieval_scores(
        results=results,
        qrels=qrels_pytrec,
        k_values=k_values,
    )

    # Format the scores
    scores = make_score_dict(
        ndcg=eval_result.ndcg,
        _map=eval_result.map,
        recall=eval_result.recall,
        precision=eval_result.precision,
        mrr=eval_result.mrr,
        naucs=eval_result.naucs,
        naucs_mrr=eval_result.naucs_mrr,
        cv_recall=eval_result.cv_recall,
        task_scores={},
    )
    return scores
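
# As a sanity check, NDCG can also be computed with pytrec_eval directly,
# independent of MTEB's calculate_retrieval_scores helper (toy ids and scores):
#
#     evaluator = pytrec_eval.RelevanceEvaluator(
#         {"q1": {"d1": 1, "d3": 1}},  # qrels
#         {"ndcg_cut.10"},
#     )
#     evaluator.evaluate({"q1": {"d1": 0.9, "d2": 0.7, "d3": 0.2}})
#     # -> {"q1": {"ndcg_cut_10": ...}}
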
def evaluate_task(
    task_name: str,
    model_name: str,
    index_path: str,
    use_fast_plaid: bool = False,
    fast_plaid_index_path: Optional[str] = None,
    language: Optional[str] = None,
    rebuild_index: bool = False,
    top_k: int = 100,
    first_stage_k: int = 500,
    k_values: list[int] = [1, 3, 5, 10, 100],
    output_dir: Optional[str] = None,
):
    """
    Evaluate a single ViDoRe v2 task.
    """
    print(f"\n{'=' * 80}")
    print(f"Evaluating task: {task_name}")
    print(f"{'=' * 80}")

    # Get the task config
    if task_name not in VIDORE_V2_TASKS:
        raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
    task_config = VIDORE_V2_TASKS[task_name]
    dataset_path = task_config["dataset_path"]
    revision = task_config["revision"]

    # Determine the language
    if language is None:
        languages = task_config.get("languages")
        if languages is None:
            # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
            language = None
        elif len(languages) == 1:
            language = languages[0]
        else:
            # Multiple languages available; evaluate without filtering
            language = None

    # Load data
    corpus, queries, qrels = load_vidore_v2_data(
        dataset_path=dataset_path,
        revision=revision,
        split="test",
        language=language,
    )

    # Check that we have queries to evaluate
    if len(queries) == 0:
        print(f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation.")
        scores = {}
        for k in k_values:
            scores[f"ndcg_at_{k}"] = 0.0
            scores[f"map_at_{k}"] = 0.0
            scores[f"recall_at_{k}"] = 0.0
            scores[f"precision_at_{k}"] = 0.0
            scores[f"mrr_at_{k}"] = 0.0
        return scores

    # Load the model
    print(f"\nLoading model: {model_name}")
    model_name_actual, model, processor, device_str, device, dtype = _load_colvision(model_name)
    print(f"Model loaded: {model_name_actual}")

    # Build or load the index
    index_path_full = fast_plaid_index_path if use_fast_plaid else index_path
    if index_path_full is None:
        index_path_full = f"./indexes/{task_name}_{model_name}"
        if use_fast_plaid:
            index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
    index_or_retriever, corpus_ids_ordered = build_index_from_corpus(
        corpus=corpus,
        model=model,
        processor=processor,
        index_path=index_path_full,
        use_fast_plaid=use_fast_plaid,
        rebuild=rebuild_index,
    )

    # Search queries
    task_prompt = task_config.get("prompt")
    results = search_queries(
        queries=queries,
        corpus_ids=corpus_ids_ordered,
        model=model,
        processor=processor,
        index_or_retriever=index_or_retriever,
        use_fast_plaid=use_fast_plaid,
        fast_plaid_index_path=fast_plaid_index_path,
        top_k=top_k,
        first_stage_k=first_stage_k,
        task_prompt=task_prompt,
    )

    # Evaluate
    scores = evaluate_results(results, qrels, k_values=k_values)

    # Print the results
    print(f"\n{'=' * 80}")
    print(f"Results for {task_name}:")
    print(f"{'=' * 80}")
    for metric, value in scores.items():
        if isinstance(value, (int, float)):
            print(f"  {metric}: {value:.5f}")

    # Save the results
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        results_file = os.path.join(output_dir, f"{task_name}_results.json")
        scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nSaved results to: {results_file}")
        with open(scores_file, "w") as f:
            json.dump(scores, f, indent=2)
        print(f"Saved scores to: {scores_file}")
    return scores

def main():
    parser = argparse.ArgumentParser(
        description="Evaluate the ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="colqwen2",
        choices=["colqwen2", "colpali"],
        help="Model to use",
    )
    parser.add_argument(
        "--task",
        type=str,
        default=None,
        help="Specific task to evaluate (or 'all' for all tasks)",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="all",
        help="Tasks to evaluate: 'all' or a comma-separated list",
    )
    parser.add_argument(
        "--index-path",
        type=str,
        default=None,
        help="Path to the LEANN index (auto-generated if not provided)",
    )
    parser.add_argument(
        "--use-fast-plaid",
        action="store_true",
        help="Use Fast-Plaid instead of LEANN",
    )
    parser.add_argument(
        "--fast-plaid-index-path",
        type=str,
        default=None,
        help="Path to the Fast-Plaid index (auto-generated if not provided)",
    )
    parser.add_argument(
        "--rebuild-index",
        action="store_true",
        help="Rebuild the index even if it exists",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language to evaluate (default: first available)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=100,
        help="Number of results to retrieve per query",
    )
    parser.add_argument(
        "--first-stage-k",
        type=int,
        default=500,
        help="First-stage k for LEANN search",
    )
    parser.add_argument(
        "--k-values",
        type=str,
        default="1,3,5,10,100",
        help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./vidore_v2_results",
        help="Output directory for results",
    )
    args = parser.parse_args()

    # Parse k_values
    k_values = [int(k.strip()) for k in args.k_values.split(",")]

    # Determine the tasks to evaluate (--task accepts a single task name or 'all')
    if args.task:
        tasks_to_eval = list(VIDORE_V2_TASKS.keys()) if args.task.lower() == "all" else [args.task]
    elif args.tasks.lower() == "all":
        tasks_to_eval = list(VIDORE_V2_TASKS.keys())
    else:
        tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
    print(f"Tasks to evaluate: {tasks_to_eval}")

    # Evaluate each task
    all_scores = {}
    for task_name in tasks_to_eval:
        try:
            scores = evaluate_task(
                task_name=task_name,
                model_name=args.model,
                index_path=args.index_path,
                use_fast_plaid=args.use_fast_plaid,
                fast_plaid_index_path=args.fast_plaid_index_path,
                language=args.language,
                rebuild_index=args.rebuild_index,
                top_k=args.top_k,
                first_stage_k=args.first_stage_k,
                k_values=k_values,
                output_dir=args.output_dir,
            )
            all_scores[task_name] = scores
        except Exception as e:
            print(f"\nError evaluating {task_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Print a summary
    if all_scores:
        print(f"\n{'=' * 80}")
        print("SUMMARY")
        print(f"{'=' * 80}")
        for task_name, scores in all_scores.items():
            print(f"\n{task_name}:")
            # Print the main metrics
            for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
                if metric in scores:
                    print(f"  {metric}: {scores[metric]:.5f}")


if __name__ == "__main__":
    main()