Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
yichuan-w
2025-11-10 21:13:17 +00:00
parent 3766ad1fd2
commit a9c014df9e
2 changed files with 304 additions and 15 deletions

View File

@@ -3,6 +3,7 @@ import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Optional, cast
@@ -194,7 +195,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
batch_size=32,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
@@ -678,11 +679,15 @@ class LeannMultiVector:
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
# load and core time
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
end_time = time.time()
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Time taken in load and core time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
@@ -710,7 +715,6 @@ class LeannMultiVector:
emb_path = self._embeddings_path()
if not emb_path.exists():
return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
@@ -718,23 +722,29 @@ class LeannMultiVector:
assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int) -> tuple[float, int]:
def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
# load and core time
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
end_time = time.time()
# print number of candidate doc ids
print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
print(f"Time taken in load and core time: {end_time - start_time} seconds")
scores.sort(key=lambda x: x[0], reverse=True)
del all_embeddings
return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:

View File

@@ -1,7 +1,9 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import argparse
import os
import time
from typing import Any, Optional
from PIL import Image
@@ -31,8 +33,33 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
# Single dataset name (used when DATASET_NAMES is None)
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train"
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
# Can be:
# - List of strings: ["dataset1", "dataset2"]
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
# - Mixed: ["dataset1", ("dataset2", "config2")]
#
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
# - "pixparse/arxiv-papers" (if available, arXiv papers)
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
# - "huggingface/document-images" (if available)
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
DATASET_NAMES = [
"weaviate/arXiv-AI-papers-multi-vector",
("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
]
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
# Set to None to try loading all available splits automatically
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
# Image field name in the dataset (auto-detect if None)
# Common names: "page_image", "image", "images", "img"
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
@@ -40,7 +67,8 @@ PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
# Use a different index path for larger dataset to avoid overwriting existing index
INDEX_PATH: str = "./indexes/colvision_large.leann"
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
@@ -54,6 +82,26 @@ ANSWER: bool = True
MAX_NEW_TOKENS: int = 1024
# %%
# CLI overrides
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
parser.add_argument(
"--search-method",
type=str,
choices=["ann", "exact", "exact-all"],
default="ann",
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
)
parser.add_argument(
"--query",
type=str,
default=QUERY,
help=f"Query string to search for. Default: '{QUERY}'",
)
cli_args, _unknown = parser.parse_known_args()
SEARCH_METHOD: str = cli_args.search_method
QUERY = cli_args.query # Override QUERY with CLI argument if provided
# %%
# Step 1: Check if we can skip data loading (index already exists)
@@ -74,18 +122,223 @@ if not REBUILD_INDEX:
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset
from datasets import load_dataset, concatenate_datasets, DatasetDict
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
# Determine which datasets to load
if DATASET_NAMES is not None:
dataset_names_to_load = DATASET_NAMES
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
else:
dataset_names_to_load = [DATASET_NAME]
print(f"Loading single dataset: {DATASET_NAME}")
# Load and combine datasets
all_datasets_to_concat = []
for dataset_entry in dataset_names_to_load:
# Handle both string and tuple formats
if isinstance(dataset_entry, tuple):
dataset_name, config_name = dataset_entry
else:
dataset_name = dataset_entry
config_name = None
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
# Load dataset to check available splits
# If config_name is provided, use it; otherwise try without config
try:
if config_name:
dataset_dict = load_dataset(dataset_name, config_name)
else:
dataset_dict = load_dataset(dataset_name)
except ValueError as e:
if "Config name is missing" in str(e):
# Try to get available configs and suggest
from datasets import get_dataset_config_names
try:
available_configs = get_dataset_config_names(dataset_name)
raise ValueError(
f"Dataset '{dataset_name}' requires a config name. "
f"Available configs: {available_configs}. "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
except Exception:
raise ValueError(
f"Dataset '{dataset_name}' requires a config name. "
f"Please specify as: ('{dataset_name}', 'config_name')"
) from e
raise
# Determine which splits to load
if DATASET_SPLITS is None:
# Auto-detect: try to load all available splits
available_splits = list(dataset_dict.keys())
print(f" Auto-detected splits: {available_splits}")
splits_to_load = available_splits
else:
splits_to_load = DATASET_SPLITS
# Load and concatenate multiple splits for this dataset
datasets_to_concat = []
for split in splits_to_load:
if split not in dataset_dict:
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
continue
split_dataset = dataset_dict[split]
print(f" Loaded split '{split}': {len(split_dataset)} pages")
datasets_to_concat.append(split_dataset)
if not datasets_to_concat:
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
continue
# Concatenate splits for this dataset
if len(datasets_to_concat) > 1:
combined_dataset = concatenate_datasets(datasets_to_concat)
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
else:
combined_dataset = datasets_to_concat[0]
all_datasets_to_concat.append(combined_dataset)
if not all_datasets_to_concat:
raise RuntimeError("No valid datasets or splits found.")
# Concatenate all datasets
if len(all_datasets_to_concat) > 1:
dataset = concatenate_datasets(all_datasets_to_concat)
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
else:
dataset = all_datasets_to_concat[0]
# Apply MAX_DOCS limit if specified
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
if N < len(dataset):
print(f"Limiting to {N} pages (from {len(dataset)} total)")
dataset = dataset.select(range(N))
# Auto-detect image field name if not specified
if IMAGE_FIELD_NAME is None:
# Check multiple samples to find the most common image field
# (useful when datasets are merged and may have different field names)
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
field_counts = {}
# Check first few samples to find image fields
num_samples_to_check = min(10, len(dataset))
for sample_idx in range(num_samples_to_check):
sample = dataset[sample_idx]
for field in possible_image_fields:
if field in sample and sample[field] is not None:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
field_counts[field] = field_counts.get(field, 0) + 1
# Choose the most common field, or first found if tied
if field_counts:
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
else:
# Fallback: check first sample only
sample = dataset[0]
image_field = None
for field in possible_image_fields:
if field in sample:
value = sample[field]
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
image_field = field
break
if image_field is None:
raise RuntimeError(
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
f"Please specify IMAGE_FIELD_NAME manually."
)
print(f"Auto-detected image field: '{image_field}'")
else:
image_field = IMAGE_FIELD_NAME
if image_field not in dataset[0]:
raise RuntimeError(
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
)
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
# Try to compose a descriptive identifier
# Handle different dataset structures
identifier_parts = []
# Helper function to safely get field value
def safe_get(field_name, default=None):
if field_name in p and p[field_name] is not None:
return p[field_name]
return default
# Try to get various identifier fields
if safe_get("paper_arxiv_id"):
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
if safe_get("paper_title"):
identifier_parts.append(f"title:{p['paper_title']}")
if safe_get("page_number") is not None:
try:
identifier_parts.append(f"page:{int(p['page_number'])}")
except (ValueError, TypeError):
# If conversion fails, use the raw value or skip
if p['page_number']:
identifier_parts.append(f"page:{p['page_number']}")
if safe_get("page_id"):
identifier_parts.append(f"id:{p['page_id']}")
elif safe_get("questionId"):
identifier_parts.append(f"qid:{p['questionId']}")
elif safe_get("docId"):
identifier_parts.append(f"docId:{p['docId']}")
elif safe_get("id"):
identifier_parts.append(f"id:{p['id']}")
# If no identifier parts found, create one from index
if identifier_parts:
identifier = "|".join(identifier_parts)
else:
# Create identifier from available fields or index
fallback_parts = []
# Try common fields that might exist
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
if safe_get(field):
fallback_parts.append(f"{field}:{p[field]}")
break
if fallback_parts:
identifier = "|".join(fallback_parts) + f"|idx:{i}"
else:
identifier = f"doc_{i}"
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
# Get image - try detected field first, then fallback to other common fields
img = None
if image_field in p and p[image_field] is not None:
img = p[image_field]
else:
# Fallback: try other common image field names
for fallback_field in ["image", "page_image", "images", "img"]:
if fallback_field in p and p[fallback_field] is not None:
img = p[fallback_field]
break
if img is None:
raise RuntimeError(
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
f"Expected field: {image_field}"
)
# Ensure it's a PIL Image
if not isinstance(img, Image.Image):
if hasattr(img, 'convert'):
img = img.convert('RGB')
else:
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
images.append(img)
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
@@ -123,8 +376,31 @@ if need_to_build_index and retriever is None:
# %%
# Step 5: Embed query and search
_t0 = time.perf_counter()
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
query_embed_secs = time.perf_counter() - _t0
query_np = q_vec.float().numpy()
print(f"[Search] Method: {SEARCH_METHOD}")
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
# Run the selected search method and time it
_t0 = time.perf_counter()
if SEARCH_METHOD == "ann":
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact":
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
elif SEARCH_METHOD == "exact-all":
results = retriever.search_exact_all(query_np, topk=TOPK)
search_secs = time.perf_counter() - _t0
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
else:
results = []
if not results:
print("No results found.")
else:
@@ -204,6 +480,9 @@ if results and SIMILARITY_MAP:
# Step 7: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
_t0 = time.perf_counter()
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
gen_secs = time.perf_counter() - _t0
print(f"[Timing] Generation: {gen_secs:.3f}s")
print("\nAnswer:")
print(response)