Files
LEANN/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
Yichuan Wang 76cc798e3e Feat/multi vector timing and dataset improvements (#181)
* Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update vidore

* reproduce docvqa results

* reproduce docvqa results and add debug file

* fix: format colqwen_forward.py to pass pre-commit checks

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-03 01:10:49 -08:00

1335 lines
48 KiB
Python

import concurrent.futures
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
def _find_backend_module_file() -> Optional[Path]:
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
this_file = Path(__file__).resolve()
candidates: list[Path] = []
# Common in-repo location
repo_root = this_file.parents[3]
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
candidates.append(
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
)
for cand in candidates:
try:
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
pass
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
for p in list(sys.path):
try:
cand = Path(p) / "leann_multi_vector.py"
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
continue
return None
# Cache for the resolved backend class; filled by the first successful lookup.
_BACKEND_LEANN_CLASS: Optional[type] = None


def _get_backend_leann_multi_vector() -> type:
    """Return the ``LeannMultiVector`` class to use, resolving it only once.

    If ``_find_backend_module_file`` locates a separate backend file, that file
    is loaded under the module name ``leann_hnsw_backend_module`` (so it does
    not clash with this module's own name) and its ``LeannMultiVector`` is
    used. Otherwise the ``LeannMultiVector`` defined in this module is used.

    Raises:
        ImportError: if no backend file is found and this module has no
            ``LeannMultiVector``, if an import spec cannot be created for the
            backend file, or if the backend file lacks ``LeannMultiVector``.
    """
    global _BACKEND_LEANN_CLASS
    if _BACKEND_LEANN_CLASS is not None:
        return _BACKEND_LEANN_CLASS

    located = _find_backend_module_file()
    if located is not None:
        import importlib.util

        alias = "leann_hnsw_backend_module"
        module_spec = importlib.util.spec_from_file_location(alias, str(located))
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Failed to create spec for backend module at {located}")
        loaded = importlib.util.module_from_spec(module_spec)
        # The module is registered under the alias before its code is executed.
        sys.modules[alias] = loaded
        module_spec.loader.exec_module(loaded)
        if not hasattr(loaded, "LeannMultiVector"):
            raise ImportError(f"'LeannMultiVector' not found in backend module at {located}")
        _BACKEND_LEANN_CLASS = loaded.LeannMultiVector
        return _BACKEND_LEANN_CLASS

    # No separate backend file: fall back to the class defined in this module.
    try:
        local_cls = LeannMultiVector  # type: ignore[name-defined]
    except Exception as exc:
        raise ImportError(
            "Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
            "Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
        ) from exc
    _BACKEND_LEANN_CLASS = local_cls
    return local_cls
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
    """Open every PNG/JPEG file in ``pages_dir``, ordered by the first number in its name.

    Args:
        pages_dir: Directory holding the page images.

    Returns:
        ``(filepaths, images)`` where ``images[i]`` was opened from ``filepaths[i]``.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    page_names = sorted(
        (entry for entry in os.listdir(pages_dir) if entry.lower().endswith(image_suffixes)),
        key=_natural_sort_key,
    )
    filepaths = [os.path.join(pages_dir, page_name) for page_name in page_names]
    images = []
    for page_path in filepaths:
        images.append(Image.open(page_path))
    return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
    """Choose the compute device and a matching floating-point dtype.

    Device preference is CUDA, then MPS, then CPU. The dtype is picked to
    avoid NaNs: on CUDA bfloat16 when supported, otherwise float16; on MPS and
    CPU float32 (fp16 on MPS can produce NaNs in some ops).

    Returns:
        ``(device_str, device, dtype)`` where ``device`` is what
        ``colpali_engine``'s ``get_torch_device`` returns for ``device_str``.
    """
    import torch
    from colpali_engine.utils.torch_utils import get_torch_device

    if torch.cuda.is_available():
        device_str = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device_str = "mps"
    else:
        device_str = "cpu"
    device = get_torch_device(device_str)

    if device_str != "cuda":
        return device_str, device, torch.float32

    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    try:
        # Better stability/perf on Ampere+; failure to set the flag is ignored.
        torch.backends.cuda.matmul.allow_tf32 = True
    except Exception:
        pass
    return device_str, device, dtype
def _load_colvision(model_choice: str):
    """Load a ColQwen2 or ColPali retrieval model together with its processor.

    Args:
        model_choice: ``"colqwen2"`` selects ``vidore/colqwen2-v1.0``; any other
            value selects ``vidore/colpali-v1.2``.

    Returns:
        ``(model_name, model, processor, device_str, device, dtype)``. The model
        is returned in eval mode.
    """
    import torch
    from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
    from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
    from transformers.utils.import_utils import is_flash_attn_2_available

    device_str, device, dtype = _select_device_and_dtype()
    # NOTE(review): both branches load the weights as bfloat16 regardless of the
    # `dtype` chosen above; `dtype` is only passed back to the caller. Confirm
    # whether that is intentional.

    if model_choice != "colqwen2":
        model_name = "vidore/colpali-v1.2"
        model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
        ).eval()
        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
        return model_name, model, processor, device_str, device, dtype

    model_name = "vidore/colqwen2-v1.0"
    # flash-attn is only used on CUDA when it is installed; CPU/MPS stay eager.
    flash_ok = device_str == "cuda" and is_flash_attn_2_available()
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation="flash_attention_2" if flash_ok else "eager",
    ).eval()
    processor = ColQwen2Processor.from_pretrained(model_name)
    return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
    """Embed page images with ``model`` in batches of 32.

    Args:
        model: Model called as ``model(**batch)``; switched to eval mode here.
        processor: Object whose ``process_images`` collates a list of images
            into the model's keyword inputs.
        images: Images to embed.

    Returns:
        One CPU tensor per input image, in input order.
    """
    import torch
    from colpali_engine.utils.torch_utils import ListDataset
    from torch.utils.data import DataLoader

    model.eval()
    loader = DataLoader(
        dataset=ListDataset[Image.Image](images),
        batch_size=32,
        shuffle=False,
        collate_fn=processor.process_images,
    )

    page_embeddings: list[Any] = []
    for batch in tqdm(loader, desc="Embedding images"):
        with torch.no_grad():
            batch = {key: tensor.to(model.device) for key, tensor in batch.items()}
            if model.device.type != "cuda":
                # CPU/MPS: run without autocast.
                output = model(**batch)
            else:
                # CUDA: autocast to the model's dtype (bfloat16 if that dtype is not floating point).
                cast_dtype = model.dtype if model.dtype.is_floating_point else torch.bfloat16
                with torch.autocast(device_type="cuda", dtype=cast_dtype):
                    output = model(**batch)
        page_embeddings.extend(torch.unbind(output.to("cpu")))
    return page_embeddings
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
    """Embed text queries in batches of 32, mirroring MTEB's query preprocessing.

    MTEB's ``ColPaliEngineWrapper.get_text_embeddings`` manually wraps each text
    as ``query_prefix + text + query_augmentation_token * 10`` and then hands it
    to ``processor.process_queries``, which adds the prefix and the 10-token
    suffix again. The duplication (prefix twice, 20 augmentation tokens) is
    reproduced here on purpose so that MTEB results can be matched.

    Args:
        model: Model called as ``model(**inputs)``; switched to eval mode here.
        processor: Provides ``query_prefix``, ``query_augmentation_token`` and
            ``process_queries``.
        queries: Query strings.

    Returns:
        One float32 CPU tensor per query, in input order.
    """
    import torch

    model.eval()
    batch_size = 32  # Match MTEB's default batch_size
    embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
            # First (manual) application of prefix + augmentation suffix.
            decorated = []
            for text in queries[start : start + batch_size]:
                decorated.append(
                    processor.query_prefix + text + processor.query_augmentation_token * 10
                )
            # process_queries applies prefix/suffix a second time.
            inputs = processor.process_queries(decorated)
            inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}
            if model.device.type != "cuda":
                outputs = model(**inputs)
            else:
                cast_dtype = model.dtype if model.dtype.is_floating_point else torch.bfloat16
                with torch.autocast(device_type="cuda", dtype=cast_dtype):
                    outputs = model(**inputs)
            # Match MTEB: move to CPU, then convert to float32.
            embeddings.extend(torch.unbind(outputs.cpu().to(torch.float32)))
    return embeddings
def _build_index(
    index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
) -> Any:
    """Build a LEANN multi-vector index from per-document embeddings.

    Document ``i`` gets ``doc_id == i`` and is stored together with
    ``filepaths[i]`` and ``images[i]``.

    Args:
        index_path: Where the index is written.
        doc_vecs: One tensor per document; the last dimension of the first one
            defines the index dimensionality.
        filepaths: Source path of each document.
        images: Original image of each document.

    Returns:
        The retriever after ``create_index()`` has been called.
    """
    backend_cls = _get_backend_leann_multi_vector()
    retriever = backend_cls(index_path=index_path, dim=int(doc_vecs[0].shape[-1]))
    retriever.create_collection()
    for doc_id, doc_vec in enumerate(doc_vecs):
        retriever.insert(
            dict(
                colbert_vecs=doc_vec.float().numpy(),
                doc_id=doc_id,
                filepath=filepaths[doc_id],
                image=images[doc_id],  # keep the original image alongside the vectors
            )
        )
    retriever.create_index()
    return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
    """Return a retriever for ``index_path`` if its files are already on disk.

    Three files next to ``index_path`` must exist: ``<stem>.index`` (the HNSW
    index written by the backend) plus the ``<name>.meta.json`` and
    ``<name>.labels.json`` sidecars. The dimensionality is read from the
    ``dimensions`` field of the meta file, falling back to 128 if the file
    cannot be read or parsed.

    Returns:
        A ``LeannMultiVector`` instance, or ``None`` when any file is missing.
    """
    backend_cls = _get_backend_leann_multi_vector()
    base = Path(index_path)
    folder = base.parent
    hnsw_file = folder / f"{base.stem}.index"
    meta_file = folder / f"{base.name}.meta.json"
    labels_file = folder / f"{base.name}.labels.json"

    if not all(required.exists() for required in (hnsw_file, meta_file, labels_file)):
        return None

    default_dim = 128
    try:
        with open(meta_file, encoding="utf-8") as handle:
            dim = int(json.load(handle).get("dimensions", default_dim))
    except Exception:
        dim = default_dim
    return backend_cls(index_path=index_path, dim=dim)
def _build_fast_plaid_index(
    index_path: str,
    doc_vecs: list[Any],
    filepaths: list[str],
    images: list[Image.Image],
) -> tuple[Any, float]:
    """
    Build a Fast-Plaid index from document embeddings and save the page images.

    Every embedding is converted to a float32 torch tensor, the index is
    created with one metadata row (``filepath``, ``index``) per entry of
    ``filepaths``, and afterwards each image is written to
    ``<index_path>/images/doc_<i>.png`` because Fast-Plaid does not store
    images. Progress is reported with ``print``.

    Args:
        index_path: Path to save the Fast-Plaid index
        doc_vecs: List of document embeddings (each is a tensor with shape [num_tokens, embedding_dim])
        filepaths: List of filepath identifiers for each document
        images: List of PIL Images corresponding to each document
    Returns:
        Tuple of (FastPlaid index object, build_time_in_seconds). The build time
        covers embedding conversion and index creation, not the image saving.
    Raises:
        Any exception raised by Fast-Plaid is re-raised after its traceback is printed.
    """
    import torch
    from fast_plaid import search as fast_plaid_search

    print(f" Preparing {len(doc_vecs)} document embeddings for Fast-Plaid...")
    started = time.perf_counter()

    # Normalise every embedding to a float32 torch tensor.
    documents_embeddings = []
    for position, embedding in enumerate(doc_vecs):
        if position % 1000 == 0:
            print(f" Converting embedding {position}/{len(doc_vecs)}...")
        if not isinstance(embedding, torch.Tensor):
            if isinstance(embedding, np.ndarray):
                embedding = torch.tensor(embedding)
            else:
                embedding = torch.from_numpy(np.array(embedding))
        if embedding.dtype != torch.float32:
            embedding = embedding.float()
        documents_embeddings.append(embedding)
    print(f" Converted {len(documents_embeddings)} embeddings")
    if len(documents_embeddings) > 0:
        first = documents_embeddings[0]
        print(f" First embedding shape: {first.shape}")
        print(f" First embedding dtype: {first.dtype}")

    # One metadata row per document; "index" matches the image file numbering below.
    print(f" Preparing metadata for {len(filepaths)} documents...")
    metadata_list = [
        {"filepath": filepath, "index": position} for position, filepath in enumerate(filepaths)
    ]

    print(f" Creating FastPlaid object with index path: {index_path}")
    try:
        fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
        print(" FastPlaid object created successfully")
    except Exception as exc:
        print(f" Error creating FastPlaid object: {type(exc).__name__}: {exc}")
        import traceback

        traceback.print_exc()
        raise

    print(f" Calling fast_plaid_index.create() with {len(documents_embeddings)} documents...")
    try:
        fast_plaid_index.create(
            documents_embeddings=documents_embeddings,
            metadata=metadata_list,
        )
        print(" Fast-Plaid index created successfully")
    except Exception as exc:
        print(f" Error creating Fast-Plaid index: {type(exc).__name__}: {exc}")
        import traceback

        traceback.print_exc()
        raise
    build_secs = time.perf_counter() - started

    # Images live in a sidecar directory since Fast-Plaid keeps only vectors and metadata.
    print(f" Saving {len(images)} images...")
    images_dir = Path(index_path) / "images"
    images_dir.mkdir(parents=True, exist_ok=True)
    for position, picture in enumerate(tqdm(images, desc="Saving images")):
        picture.save(str(images_dir / f"doc_{position}.png"))
    return fast_plaid_index, build_secs
def _fast_plaid_index_exists(index_path: str) -> bool:
"""
Check if Fast-Plaid index exists by checking for key files.
This avoids creating the FastPlaid object which may trigger memory allocation.
Args:
index_path: Path to the Fast-Plaid index
Returns:
True if index appears to exist, False otherwise
"""
index_path_obj = Path(index_path)
if not index_path_obj.exists() or not index_path_obj.is_dir():
return False
# Fast-Plaid creates a SQLite database file for metadata
# Check for metadata.db as the most reliable indicator
metadata_db = index_path_obj / "metadata.db"
if metadata_db.exists() and metadata_db.stat().st_size > 0:
return True
# Also check if directory has any files (might be incomplete index)
try:
if any(index_path_obj.iterdir()):
return True
except Exception:
pass
return False
def _load_fast_plaid_index_if_exists(index_path: str) -> Optional[Any]:
"""
Load Fast-Plaid index if it exists.
First checks if index files exist, then creates the FastPlaid object.
The actual index data loading happens lazily when search is called.
Args:
index_path: Path to the Fast-Plaid index
Returns:
FastPlaid index object if exists, None otherwise
"""
try:
from fast_plaid import search as fast_plaid_search
# First check if index files exist without creating the object
if not _fast_plaid_index_exists(index_path):
return None
# Now try to create FastPlaid object
# This may trigger some memory allocation, but the full index loading is deferred
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
return fast_plaid_index
except ImportError:
# fast-plaid not installed
return None
except Exception as e:
# Any error (including memory errors from Rust backend) - return None
# The error will be caught and index will be rebuilt
print(f"Warning: Could not load Fast-Plaid index: {type(e).__name__}: {e}")
return None
def _search_fast_plaid(
fast_plaid_index: Any,
query_vec: Any,
top_k: int,
) -> tuple[list[tuple[float, int]], float]:
"""
Search Fast-Plaid index with a query embedding.
Args:
fast_plaid_index: FastPlaid index object
query_vec: Query embedding tensor with shape [num_tokens, embedding_dim]
top_k: Number of top results to return
Returns:
Tuple of (results_list, search_time_in_seconds)
results_list: List of (score, doc_id) tuples
"""
import torch
_t0 = time.perf_counter()
# Ensure query is a torch tensor
if not isinstance(query_vec, torch.Tensor):
q_vec_tensor = (
torch.tensor(query_vec)
if isinstance(query_vec, np.ndarray)
else torch.from_numpy(np.array(query_vec))
)
else:
q_vec_tensor = query_vec
# Fast-Plaid expects shape [num_queries, num_tokens, embedding_dim]
if q_vec_tensor.dim() == 2:
q_vec_tensor = q_vec_tensor.unsqueeze(0) # [1, num_tokens, embedding_dim]
# Perform search
scores = fast_plaid_index.search(
queries_embeddings=q_vec_tensor,
top_k=top_k,
show_progress=True,
)
search_secs = time.perf_counter() - _t0
# Convert Fast-Plaid results to same format as LEANN: list of (score, doc_id) tuples
results = []
if scores and len(scores) > 0:
query_results = scores[0]
# Fast-Plaid returns (doc_id, score), convert to (score, doc_id) to match LEANN format
results = [(float(score), int(doc_id)) for doc_id, score in query_results]
return results, search_secs
def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]:
    """
    Open the stored page image for a Fast-Plaid search hit.

    Images are looked up as ``<index_path>/images/doc_<n>.png``. ``n`` is the
    ``index`` field of the document's metadata when available (the Fast-Plaid
    doc_id may differ from the file naming index); ``doc_id`` itself is tried
    when there is no metadata or that file is missing.

    Args:
        index_path: Path to the Fast-Plaid index
        doc_id: Document ID returned by Fast-Plaid search
    Returns:
        PIL Image if found, None otherwise
    """
    record = _get_fast_plaid_metadata(index_path, doc_id)
    file_index = doc_id if record is None else record.get("index", doc_id)

    # Preferred name first, then doc_id as a fallback when it is different.
    names_to_try = [file_index]
    if file_index != doc_id:
        names_to_try.append(doc_id)

    images_dir = Path(index_path) / "images"
    for name in names_to_try:
        candidate = images_dir / f"doc_{name}.png"
        if candidate.exists():
            return Image.open(candidate)
    return None
def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a document from Fast-Plaid index.
Args:
index_path: Path to the Fast-Plaid index
doc_id: Document ID
Returns:
Dictionary with metadata if found, None otherwise
"""
try:
from fast_plaid import filtering
metadata_list = filtering.get(index=index_path, subset=[doc_id])
if metadata_list and len(metadata_list) > 0:
return metadata_list[0]
except Exception:
pass
return None
def _generate_similarity_map(
    model,
    processor,
    image: Image.Image,
    query: str,
    token_idx: Optional[int] = None,
    output_path: Optional[str] = None,
) -> tuple[int, float]:
    """Compute the query-token/image similarity map and optionally save a plot of it.

    Args:
        model: ColVision model used to embed both the image and the query.
        processor: Matching processor (``process_images``, ``process_queries``,
            ``get_n_patches``, ``get_image_mask``).
        image: Page image to analyse.
        query: Query text.
        token_idx: Query token whose map is used. When ``None``, the token with
            the highest maximum similarity is chosen.
        output_path: If given, the map for ``token_idx`` is plotted and saved
            there. Missing parent directories are created; a bare file name
            (no directory part) is saved into the current directory.

    Returns:
        ``(token_idx, max_sim_score)``: the token index used and the maximum
        value of its similarity map.
    """
    import torch
    from colpali_engine.interpretability import (
        get_similarity_maps_from_embeddings,
        plot_similarity_map,
    )

    batch_images = processor.process_images([image]).to(model.device)
    batch_queries = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        image_embeddings = model.forward(**batch_images)
        query_embeddings = model.forward(**batch_queries)

    n_patches = processor.get_n_patches(
        image_size=image.size,
        spatial_merge_size=getattr(model, "spatial_merge_size", None),
    )
    image_mask = processor.get_image_mask(batch_images)
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )
    # Single image/query pair, so only the first entry of the batch is relevant.
    similarity_maps = batched_similarity_maps[0]

    # Determine token index if not provided: choose the token with highest max score
    if token_idx is None:
        per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
        token_idx = int(per_token_max.argmax().item())
    max_sim_score = similarity_maps[token_idx, :, :].max().item()

    if output_path:
        import matplotlib.pyplot as plt

        fig, ax = plot_similarity_map(
            image=image,
            similarity_map=similarity_maps[token_idx],
            figsize=(14, 14),
            show_colorbar=False,
        )
        ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the output path actually has a directory component.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        # Save through the figure object rather than pyplot's implicit "current figure".
        fig.savefig(output_path, bbox_inches="tight")
        plt.close(fig)
    return token_idx, float(max_sim_score)
class QwenVL:
    """Wrapper around ``Qwen/Qwen2.5-VL-3B-Instruct`` that answers a text query about images."""

    def __init__(self, device: str):
        """Load the model and its processor.

        Args:
            device: Passed to ``from_pretrained`` as ``device_map``.
        """
        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
        from transformers.utils.import_utils import is_flash_attn_2_available

        # Consistent with _load_colvision: never request flash-attn when the
        # caller explicitly targets CPU or MPS, even if the package is installed.
        device_kind = str(device).split(":", 1)[0].lower()
        use_flash = device_kind not in ("cpu", "mps") and is_flash_attn_2_available()
        attn_implementation = "flash_attention_2" if use_flash else "eager"
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            torch_dtype="auto",
            device_map=device,
            attn_implementation=attn_implementation,
        )
        # Pixel budget handed to the processor (expressed in multiples of 28 * 28).
        min_pixels = 256 * 28 * 28
        max_pixels = 1280 * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
        )

    def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
        """Generate an answer to ``query`` given ``images``.

        Each image is JPEG-encoded and embedded into a single user message as a
        base64 data URI, followed by the query text.

        Args:
            query: Question to ask.
            images: Images shown to the model, in order.
            max_new_tokens: Generation length limit.

        Returns:
            The decoded completion, with the prompt tokens stripped.
        """
        import base64
        from io import BytesIO

        from qwen_vl_utils import process_vision_info

        content = []
        for img in images:
            # JPEG cannot encode alpha/palette modes (e.g. RGBA or P pages loaded
            # from PNG files); convert those instead of failing in img.save().
            if img.mode not in ("RGB", "L", "CMYK"):
                img = img.convert("RGB")
            buffer = BytesIO()
            img.save(buffer, format="jpeg")
            img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
            content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
        content.append({"type": "text", "text": query})
        messages = [{"role": "user", "content": content}]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; drop the prompt part of each row.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
# Put the in-repo leann package directories on sys.path before importing the
# HNSW backend below, so the import also works when the packages are not
# installed. The import is intentionally not at the top of the file (E402).
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
class LeannMultiVector:
def __init__(
self,
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
) -> None:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
self._docid_to_indices: dict[int, list[int]] | None = None
def _meta_dict(self) -> dict:
return {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": self.embedding_model_name,
"embedding_mode": "custom",
"dimensions": self.dim,
"backend_kwargs": self._backend_kwargs,
"is_compact": self._backend_kwargs.get("is_compact", True),
"is_pruned": self._backend_kwargs.get("is_compact", True)
and self._backend_kwargs.get("is_recompute", True),
}
def create_collection(self) -> None:
path = Path(self.index_path)
path.parent.mkdir(parents=True, exist_ok=True)
def insert(self, data: dict) -> None:
self._pending_items.append(
{
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
"image": data.get("image"), # PIL Image object (optional)
}
)
def _labels_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
def _meta_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def _embeddings_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
def _images_dir_path(self) -> Path:
"""Directory where original images are stored."""
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.images"
def create_index(self) -> None:
if not self._pending_items:
return
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
# Create images directory if needed
images_dir = self._images_dir_path()
images_dir.mkdir(parents=True, exist_ok=True)
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
image = item.get("image")
# Save image if provided
image_path = ""
if image is not None and isinstance(image, Image.Image):
image_filename = f"doc_{doc_id}.png"
image_path = str(images_dir / image_filename)
image.save(image_path, "PNG")
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
labels_meta.append(
{
"id": f"{doc_id}:{seq_id}",
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
"image_path": image_path, # Store the path to the saved image
}
)
if not embeddings:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
# Persist embeddings for exact reranking
np.save(self._embeddings_path(), embeddings_np)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
if self._labels_meta:
return
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def _build_docid_to_indices_if_needed(self) -> None:
if self._docid_to_indices is not None:
return
self._load_labels_meta_if_needed()
mapping: dict[int, list[int]] = {}
for idx, meta in enumerate(self._labels_meta):
try:
doc_id = int(meta["doc_id"]) # type: ignore[index]
except Exception:
continue
mapping.setdefault(doc_id, []).append(idx)
self._docid_to_indices = mapping
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
distances = raw.get("distances")
if labels is None or distances is None:
return []
doc_scores: dict[int, float] = {}
B = len(labels)
for b in range(B):
per_doc_best: dict[int, float] = {}
for k, sid in enumerate(labels[b]):
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
else:
continue
score = float(distances[b][k])
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
per_doc_best[doc_id] = score
for doc_id, best_score in per_doc_best.items():
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact(
self,
data: np.ndarray,
topk: int,
*,
first_stage_k: int = 200,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
High-precision MaxSim reranking over candidate documents.
Steps:
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
2) For each candidate doc, load all its token embeddings and compute
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
Returns top-k list of (score, doc_id).
"""
# Normalize inputs
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
# Fallback to approximate if we don't have persisted embeddings
return self.search(data, topk, first_stage_k=first_stage_k)
# Memory-map embeddings to avoid loading all into RAM
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
# First-stage ANN to collect candidate doc_ids
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
if labels is None:
return []
candidate_doc_ids: set[int] = set()
for batch in labels:
for sid in batch:
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
# Exact scoring per doc (parallelized)
assert self._docid_to_indices is not None
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
sim = np.dot(data, doc_vecs.T)
# nan-safe
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
# load and core time
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
end_time = time.time()
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Time taken in load and core time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact_all(
self,
data: np.ndarray,
topk: int,
*,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
Exact MaxSim over ALL documents (no ANN pre-filtering).
This computes, for each document, sum_i max_j dot(q_i, d_j).
It memory-maps the persisted token-embedding matrix for scalability.
"""
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
# load and core time
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
end_time = time.time()
# print number of candidate doc ids
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Time taken in load and core time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True)
del all_embeddings
return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve the original image for a given doc_id from the index.
Args:
doc_id: The document ID
Returns:
PIL Image object if found, None otherwise
"""
self._load_labels_meta_if_needed()
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
image_path = meta.get("image_path", "")
if image_path and Path(image_path).exists():
return Image.open(image_path)
break
return None
def get_metadata(self, doc_id: int) -> Optional[dict]:
    """
    Look up the stored metadata of a document.

    The first metadata record whose "doc_id" equals ``doc_id`` is used.

    Args:
        doc_id: The document ID

    Returns:
        Dictionary with "doc_id", "filepath" and "image_path" (missing fields
        become empty strings) if a record is found, None otherwise
    """
    self._load_labels_meta_if_needed()

    record = next(
        (entry for entry in self._labels_meta if entry.get("doc_id") == doc_id),
        None,
    )
    if record is None:
        return None

    summary = {"doc_id": doc_id}
    for field_name in ("filepath", "image_path"):
        summary[field_name] = record.get(field_name, "")
    return summary
class ViDoReBenchmarkEvaluator:
"""
A reusable class for evaluating ViDoRe benchmarks (v1 and v2).
This class encapsulates common functionality for building indexes, searching, and evaluating.
"""
def __init__(
self,
model_name: str,
use_fast_plaid: bool = False,
top_k: int = 100,
first_stage_k: int = 500,
k_values: Optional[list[int]] = None,
):
"""
Initialize the evaluator.
Args:
model_name: Model name ("colqwen2" or "colpali")
use_fast_plaid: Whether to use Fast-Plaid instead of LEANN
top_k: Top-k results to retrieve
first_stage_k: First stage k for LEANN search
k_values: List of k values for evaluation metrics
"""
self.model_name = model_name
self.use_fast_plaid = use_fast_plaid
self.top_k = top_k
self.first_stage_k = first_stage_k
self.k_values = k_values if k_values is not None else [1, 3, 5, 10, 100]
# Load model once (can be reused across tasks)
self._model = None
self._processor = None
self._model_name_actual = None
def _load_model_if_needed(self):
"""Lazy load the model."""
if self._model is None:
print(f"\nLoading model: {self.model_name}")
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
self.model_name
)
print(f"Model loaded: {self._model_name_actual}")
def build_index_from_corpus(
    self,
    corpus: dict[str, Image.Image],
    index_path: str,
    rebuild: bool = False,
) -> tuple[Any, list[str]]:
    """
    Build index from corpus images, or reuse an index already on disk.

    Args:
        corpus: dict mapping corpus_id to PIL Image
        index_path: Path to save/load the index
        rebuild: Whether to rebuild even if index exists

    Returns:
        tuple: (retriever or fast_plaid_index object, list of corpus_ids in order)
    """
    self._load_model_if_needed()

    # Ensure consistent ordering: positions in corpus_ids follow sorted ids.
    corpus_ids = sorted(corpus)
    images = [corpus[cid] for cid in corpus_ids]

    if self.use_fast_plaid:
        if not rebuild:
            # Load once and reuse the handle instead of loading a second
            # time just to return it.
            existing_index = _load_fast_plaid_index_if_exists(index_path)
            if existing_index is not None:
                print(f"Fast-Plaid index already exists at {index_path}")
                return existing_index, corpus_ids

        print(f"Building Fast-Plaid index at {index_path}...")
        print("Embedding images...")
        doc_vecs = _embed_images(self._model, self._processor, images)
        fast_plaid_index, build_time = _build_fast_plaid_index(
            index_path, doc_vecs, corpus_ids, images
        )
        print(f"Fast-Plaid index built in {build_time:.2f}s")
        return fast_plaid_index, corpus_ids

    # LEANN path
    if not rebuild:
        retriever = _load_retriever_if_index_exists(index_path)
        if retriever is not None:
            print(f"LEANN index already exists at {index_path}")
            return retriever, corpus_ids

    print(f"Building LEANN index at {index_path}...")
    print("Embedding images...")
    doc_vecs = _embed_images(self._model, self._processor, images)
    retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
    print("LEANN index built")
    return retriever, corpus_ids
def search_queries(
self,
queries: dict[str, str],
corpus_ids: list[str],
index_or_retriever: Any,
fast_plaid_index_path: Optional[str] = None,
task_prompt: Optional[dict[str, str]] = None,
) -> dict[str, dict[str, float]]:
"""
Search queries against the index.
Args:
queries: dict mapping query_id to query text
corpus_ids: list of corpus_ids in the same order as the index
index_or_retriever: index or retriever object
fast_plaid_index_path: path to Fast-Plaid index (for metadata)
task_prompt: Optional dict with prompt for query (e.g., {"query": "..."})
Returns:
results: dict mapping query_id to dict of {corpus_id: score}
"""
self._load_model_if_needed()
print(f"Searching {len(queries)} queries (top_k={self.top_k})...")
query_ids = list(queries.keys())
query_texts = [queries[qid] for qid in query_ids]
# Note: ColPaliEngineWrapper does NOT use task prompt from metadata
# It uses query_prefix + text + query_augmentation_token (handled in _embed_queries)
# So we don't append task_prompt here to match MTEB behavior
# Embed queries
print("Embedding queries...")
query_vecs = _embed_queries(self._model, self._processor, query_texts)
results = {}
for query_id, query_vec in zip(tqdm(query_ids, desc="Searching"), query_vecs):
if self.use_fast_plaid:
# Fast-Plaid search
search_results, _ = _search_fast_plaid(index_or_retriever, query_vec, self.top_k)
query_results = {}
for score, doc_id in search_results:
if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score)
else:
# LEANN search
import torch
query_np = (
query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
)
search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
query_results = {}
for score, doc_id in search_results:
if doc_id < len(corpus_ids):
corpus_id = corpus_ids[doc_id]
query_results[corpus_id] = float(score)
results[query_id] = query_results
return results
@staticmethod
def evaluate_results(
results: dict[str, dict[str, float]],
qrels: dict[str, dict[str, int]],
k_values: Optional[list[int]] = None,
) -> dict[str, float]:
"""
Evaluate retrieval results using NDCG and other metrics.
Args:
results: dict mapping query_id to dict of {corpus_id: score}
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
k_values: List of k values for evaluation metrics
Returns:
Dictionary of metric scores
"""
try:
from mteb._evaluators.retrieval_metrics import (
calculate_retrieval_scores,
make_score_dict,
)
except ImportError:
raise ImportError(
"pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
)
if k_values is None:
k_values = [1, 3, 5, 10, 100]
# Check if we have any queries to evaluate
if len(results) == 0:
print("Warning: No queries to evaluate. Returning zero scores.")
scores = {}
for k in k_values:
scores[f"ndcg_at_{k}"] = 0.0
scores[f"map_at_{k}"] = 0.0
scores[f"recall_at_{k}"] = 0.0
scores[f"precision_at_{k}"] = 0.0
scores[f"mrr_at_{k}"] = 0.0
return scores
print(f"Evaluating results with k_values={k_values}...")
print(f"Before filtering: {len(results)} results, {len(qrels)} qrels")
# Filter to ensure qrels and results have the same query set
# This matches MTEB behavior: only evaluate queries that exist in both
# pytrec_eval only evaluates queries in qrels, so we need to ensure
# results contains all queries in qrels, and filter out queries not in qrels
results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
qrels_filtered = {
qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
}
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
if len(results_filtered) != len(qrels_filtered):
print(
f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
)
missing_in_results = set(qrels.keys()) - set(results.keys())
if missing_in_results:
print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
print(f"First 5 missing queries: {list(missing_in_results)[:5]}")
# Convert qrels to pytrec_eval format
qrels_pytrec = {}
for qid, rel_docs in qrels_filtered.items():
qrels_pytrec[qid] = dict(rel_docs.items())
# Evaluate
eval_result = calculate_retrieval_scores(
results=results_filtered,
qrels=qrels_pytrec,
k_values=k_values,
)
# Format scores
scores = make_score_dict(
ndcg=eval_result.ndcg,
_map=eval_result.map,
recall=eval_result.recall,
precision=eval_result.precision,
mrr=eval_result.mrr,
naucs=eval_result.naucs,
naucs_mrr=eval_result.naucs_mrr,
cv_recall=eval_result.cv_recall,
task_scores={},
)
return scores