Add timing instrumentation and multi-dataset support for multi-vector retrieval

- Add timing measurements for search operations (load and core time; see the scoring sketch below)
- Increase embedding batch size from 1 to 32 for better performance
- Add explicit memory cleanup with del all_embeddings
- Support loading and merging multiple datasets with different splits
- Add CLI arguments for search method selection (ann/exact/exact-all)
- Auto-detect image field names across different dataset structures
- Print candidate doc counts for performance monitoring
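
For context on what the timed "exact" paths compute: each candidate document is scored with late-interaction (MaxSim) over token embeddings. A minimal NumPy sketch of that scoring (shapes and variable names are illustrative, not taken from this diff):

```python
import numpy as np

def maxsim_score(query_tokens: np.ndarray, doc_tokens: np.ndarray) -> float:
    """Late-interaction (MaxSim): for each query token, take the best-matching
    document token similarity, then sum over query tokens. Inputs are (n_tokens, dim)."""
    sim = query_tokens @ doc_tokens.T    # (n_query_tokens, n_doc_tokens) token similarities
    return float(sim.max(axis=1).sum())  # best doc token per query token, summed

# Illustrative shapes only: 16 query tokens, 900 doc tokens, 128-dim embeddings
rng = np.random.default_rng(0)
q = rng.standard_normal((16, 128)).astype(np.float32)
d = rng.standard_normal((900, 128)).astype(np.float32)
print(maxsim_score(q, d))
```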

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yichuan-w committed 2025-11-10 21:13:17 +00:00
parent 3766ad1fd2
commit a9c014df9e
2 changed files with 304 additions and 15 deletions

File 1 of 2 (LeannMultiVector retriever module):

@@ -3,6 +3,7 @@ import json
 import os
 import re
 import sys
+import time
 from pathlib import Path
 from typing import Any, Optional, cast
@@ -194,7 +195,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
     dataloader = DataLoader(
         dataset=ListDataset[Image.Image](images),
-        batch_size=1,
+        batch_size=32,
         shuffle=False,
         collate_fn=lambda x: processor.process_images(x),
     )
@@ -678,11 +679,15 @@ class LeannMultiVector:
             return (float(score), doc_id)

         scores: list[tuple[float, int]] = []
+        # load and core time
+        start_time = time.time()
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
             futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
             for fut in concurrent.futures.as_completed(futures):
                 scores.append(fut.result())
+        end_time = time.time()
+        print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
+        print(f"Time taken in load and core time: {end_time - start_time} seconds")
         scores.sort(key=lambda x: x[0], reverse=True)
         return scores[:topk] if len(scores) >= topk else scores
@@ -710,7 +715,6 @@ class LeannMultiVector:
         emb_path = self._embeddings_path()
         if not emb_path.exists():
             return self.search(data, topk)
-
         all_embeddings = np.load(emb_path, mmap_mode="r")
         if all_embeddings.dtype != np.float32:
             all_embeddings = all_embeddings.astype(np.float32)
@@ -718,23 +722,29 @@ class LeannMultiVector:
         assert self._docid_to_indices is not None
         candidate_doc_ids = list(self._docid_to_indices.keys())

-        def _score_one(doc_id: int) -> tuple[float, int]:
+        def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
             token_indices = self._docid_to_indices.get(doc_id, [])
             if not token_indices:
                 return (0.0, doc_id)
-            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
             sim = np.dot(data, doc_vecs.T)
             sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
             score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
             return (float(score), doc_id)

         scores: list[tuple[float, int]] = []
+        # load and core time
+        start_time = time.time()
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
             futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
             for fut in concurrent.futures.as_completed(futures):
                 scores.append(fut.result())
+        end_time = time.time()
+        # print number of candidate doc ids
+        print(f"Number of candidate doc ids: {len(candidate_doc_ids)}")
+        print(f"Time taken in load and core time: {end_time - start_time} seconds")
         scores.sort(key=lambda x: x[0], reverse=True)
+        del all_embeddings
         return scores[:topk] if len(scores) >= topk else scores

     def get_image(self, doc_id: int) -> Optional[Image.Image]:
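
A note on the `del all_embeddings` cleanup above: `np.load(..., mmap_mode="r")` returns a `numpy.memmap` backed by the file rather than an in-memory copy, and dropping the last reference lets the mapping be released. A minimal, self-contained sketch of the same pattern (the file name is a stand-in, not from this commit):

```python
import numpy as np

# Hypothetical stand-in for the on-disk token-embedding matrix written at index time.
np.save("embeddings.npy", np.random.rand(1000, 128).astype(np.float32))

# Memory-map instead of loading the whole array; rows are read from disk on access.
all_embeddings = np.load("embeddings.npy", mmap_mode="r")
doc_vecs = np.asarray(all_embeddings[[0, 2, 5]], dtype=np.float32)  # materialize only needed rows

del all_embeddings  # drop the reference so the memory map can be released promptly
```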

File 2 of 2 (Jupyter-style notebook script):

@@ -1,7 +1,9 @@
 ## Jupyter-style notebook script
 # %%
 # uv pip install matplotlib qwen_vl_utils
+import argparse
 import os
+import time
 from typing import Any, Optional

 from PIL import Image
@@ -31,8 +33,33 @@ MODEL: str = "colqwen2" # "colpali" or "colqwen2"
 # Data source: set to True to use the Hugging Face dataset example (recommended)
 USE_HF_DATASET: bool = True
+# Single dataset name (used when DATASET_NAMES is None)
 DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
-DATASET_SPLIT: str = "train"
+# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
+# Can be:
+# - List of strings: ["dataset1", "dataset2"]
+# - List of tuples: [("dataset1", "config1"), ("dataset2", None)]  # None = no config needed
+# - Mixed: ["dataset1", ("dataset2", "config2")]
+#
+# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
+# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
+# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
+# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
+# - "pixparse/arxiv-papers" (if available, arXiv papers)
+# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
+# - "huggingface/document-images" (if available)
+# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
+# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
+DATASET_NAMES = [
+    "weaviate/arXiv-AI-papers-multi-vector",
+    ("lmms-lab/DocVQA", "DocVQA"),  # Specify config name for datasets with multiple configs
+]
+# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
+# Set to None to try loading all available splits automatically
+DATASET_SPLITS: Optional[list[str]] = ["train", "test"]  # None = auto-detect all splits
+# Image field name in the dataset (auto-detect if None)
+# Common names: "page_image", "image", "images", "img"
+IMAGE_FIELD_NAME: Optional[str] = None  # None = auto-detect
 MAX_DOCS: Optional[int] = None  # limit number of pages to index; None = all

 # Local pages (used when USE_HF_DATASET == False)
# Local pages (used when USE_HF_DATASET == False) # Local pages (used when USE_HF_DATASET == False)
@@ -40,7 +67,8 @@ PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
 PAGES_DIR: str = "./pages"

 # Index + retrieval settings
-INDEX_PATH: str = "./indexes/colvision.leann"
+# Use a different index path for larger dataset to avoid overwriting existing index
+INDEX_PATH: str = "./indexes/colvision_large.leann"
 TOPK: int = 3
 FIRST_STAGE_K: int = 500
 REBUILD_INDEX: bool = False
@@ -54,6 +82,26 @@ ANSWER: bool = True
 MAX_NEW_TOKENS: int = 1024

+# %%
+# CLI overrides
+parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
+parser.add_argument(
+    "--search-method",
+    type=str,
+    choices=["ann", "exact", "exact-all"],
+    default="ann",
+    help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
+)
+parser.add_argument(
+    "--query",
+    type=str,
+    default=QUERY,
+    help=f"Query string to search for. Default: '{QUERY}'",
+)
+cli_args, _unknown = parser.parse_known_args()
+SEARCH_METHOD: str = cli_args.search_method
+QUERY = cli_args.query  # Override QUERY with CLI argument if provided
+
 # %%
 # Step 1: Check if we can skip data loading (index already exists)
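
The hunk above uses `parse_known_args()` rather than `parse_args()`, which tolerates unrecognized arguments (useful when the file also runs as a Jupyter-style notebook and the kernel injects its own flags). A small standalone sketch of that behavior, with a hypothetical extra flag:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--search-method", choices=["ann", "exact", "exact-all"], default="ann")

# parse_known_args() returns (namespace, leftover); unknown flags are ignored instead of erroring.
args, unknown = parser.parse_known_args(["--search-method", "exact", "--some-kernel-flag=1"])
print(args.search_method)  # "exact"
print(unknown)             # ["--some-kernel-flag=1"]
```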
@@ -74,18 +122,223 @@ if not REBUILD_INDEX:
 if need_to_build_index:
     print("Loading dataset...")
     if USE_HF_DATASET:
-        from datasets import load_dataset
+        from datasets import load_dataset, concatenate_datasets, DatasetDict

-        dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
+        # Determine which datasets to load
+        if DATASET_NAMES is not None:
+            dataset_names_to_load = DATASET_NAMES
+            print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
+        else:
+            dataset_names_to_load = [DATASET_NAME]
+            print(f"Loading single dataset: {DATASET_NAME}")
+
+        # Load and combine datasets
+        all_datasets_to_concat = []
+        for dataset_entry in dataset_names_to_load:
+            # Handle both string and tuple formats
+            if isinstance(dataset_entry, tuple):
+                dataset_name, config_name = dataset_entry
+            else:
+                dataset_name = dataset_entry
+                config_name = None
+            print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
+
+            # Load dataset to check available splits
+            # If config_name is provided, use it; otherwise try without config
+            try:
+                if config_name:
+                    dataset_dict = load_dataset(dataset_name, config_name)
+                else:
+                    dataset_dict = load_dataset(dataset_name)
+            except ValueError as e:
+                if "Config name is missing" in str(e):
+                    # Try to get available configs and suggest
+                    from datasets import get_dataset_config_names
+                    try:
+                        available_configs = get_dataset_config_names(dataset_name)
+                        raise ValueError(
+                            f"Dataset '{dataset_name}' requires a config name. "
+                            f"Available configs: {available_configs}. "
+                            f"Please specify as: ('{dataset_name}', 'config_name')"
+                        ) from e
+                    except Exception:
+                        raise ValueError(
+                            f"Dataset '{dataset_name}' requires a config name. "
+                            f"Please specify as: ('{dataset_name}', 'config_name')"
+                        ) from e
+                raise
+
+            # Determine which splits to load
+            if DATASET_SPLITS is None:
+                # Auto-detect: try to load all available splits
+                available_splits = list(dataset_dict.keys())
+                print(f" Auto-detected splits: {available_splits}")
+                splits_to_load = available_splits
+            else:
+                splits_to_load = DATASET_SPLITS
+
+            # Load and concatenate multiple splits for this dataset
+            datasets_to_concat = []
+            for split in splits_to_load:
+                if split not in dataset_dict:
+                    print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
+                    continue
+                split_dataset = dataset_dict[split]
+                print(f" Loaded split '{split}': {len(split_dataset)} pages")
+                datasets_to_concat.append(split_dataset)
+
+            if not datasets_to_concat:
+                print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
+                continue
+
+            # Concatenate splits for this dataset
+            if len(datasets_to_concat) > 1:
+                combined_dataset = concatenate_datasets(datasets_to_concat)
+                print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
+            else:
+                combined_dataset = datasets_to_concat[0]
+            all_datasets_to_concat.append(combined_dataset)
+
+        if not all_datasets_to_concat:
+            raise RuntimeError("No valid datasets or splits found.")
+
+        # Concatenate all datasets
+        if len(all_datasets_to_concat) > 1:
+            dataset = concatenate_datasets(all_datasets_to_concat)
+            print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
+        else:
+            dataset = all_datasets_to_concat[0]
+
+        # Apply MAX_DOCS limit if specified
         N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
+        if N < len(dataset):
+            print(f"Limiting to {N} pages (from {len(dataset)} total)")
+            dataset = dataset.select(range(N))
+
+        # Auto-detect image field name if not specified
+        if IMAGE_FIELD_NAME is None:
+            # Check multiple samples to find the most common image field
+            # (useful when datasets are merged and may have different field names)
+            possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
+            field_counts = {}
+            # Check first few samples to find image fields
+            num_samples_to_check = min(10, len(dataset))
+            for sample_idx in range(num_samples_to_check):
+                sample = dataset[sample_idx]
+                for field in possible_image_fields:
+                    if field in sample and sample[field] is not None:
+                        value = sample[field]
+                        if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
+                            field_counts[field] = field_counts.get(field, 0) + 1
+            # Choose the most common field, or first found if tied
+            if field_counts:
+                image_field = max(field_counts.items(), key=lambda x: x[1])[0]
+                print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
+            else:
+                # Fallback: check first sample only
+                sample = dataset[0]
+                image_field = None
+                for field in possible_image_fields:
+                    if field in sample:
+                        value = sample[field]
+                        if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
+                            image_field = field
+                            break
+                if image_field is None:
+                    raise RuntimeError(
+                        f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
+                        f"Please specify IMAGE_FIELD_NAME manually."
+                    )
+                print(f"Auto-detected image field: '{image_field}'")
+        else:
+            image_field = IMAGE_FIELD_NAME
+            if image_field not in dataset[0]:
+                raise RuntimeError(
+                    f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
+                )
+
         filepaths: list[str] = []
         images: list[Image.Image] = []
-        for i in tqdm(range(N), desc="Loading dataset", total=N):
+        for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
             p = dataset[i]
-            # Compose a descriptive identifier for printing later
-            identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
+            # Try to compose a descriptive identifier
+            # Handle different dataset structures
+            identifier_parts = []
+
+            # Helper function to safely get field value
+            def safe_get(field_name, default=None):
+                if field_name in p and p[field_name] is not None:
+                    return p[field_name]
+                return default
+
+            # Try to get various identifier fields
+            if safe_get("paper_arxiv_id"):
+                identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
+            if safe_get("paper_title"):
+                identifier_parts.append(f"title:{p['paper_title']}")
+            if safe_get("page_number") is not None:
+                try:
+                    identifier_parts.append(f"page:{int(p['page_number'])}")
+                except (ValueError, TypeError):
+                    # If conversion fails, use the raw value or skip
+                    if p['page_number']:
+                        identifier_parts.append(f"page:{p['page_number']}")
+            if safe_get("page_id"):
+                identifier_parts.append(f"id:{p['page_id']}")
+            elif safe_get("questionId"):
+                identifier_parts.append(f"qid:{p['questionId']}")
+            elif safe_get("docId"):
+                identifier_parts.append(f"docId:{p['docId']}")
+            elif safe_get("id"):
+                identifier_parts.append(f"id:{p['id']}")
+
+            # If no identifier parts found, create one from index
+            if identifier_parts:
+                identifier = "|".join(identifier_parts)
+            else:
+                # Create identifier from available fields or index
+                fallback_parts = []
+                # Try common fields that might exist
+                for field in ["ucsf_document_id", "docId", "questionId", "id"]:
+                    if safe_get(field):
+                        fallback_parts.append(f"{field}:{p[field]}")
+                        break
+                if fallback_parts:
+                    identifier = "|".join(fallback_parts) + f"|idx:{i}"
+                else:
+                    identifier = f"doc_{i}"
+
             filepaths.append(identifier)
-            images.append(p["page_image"])  # PIL Image
+
+            # Get image - try detected field first, then fallback to other common fields
+            img = None
+            if image_field in p and p[image_field] is not None:
+                img = p[image_field]
+            else:
+                # Fallback: try other common image field names
+                for fallback_field in ["image", "page_image", "images", "img"]:
+                    if fallback_field in p and p[fallback_field] is not None:
+                        img = p[fallback_field]
+                        break
+            if img is None:
+                raise RuntimeError(
+                    f"No image found for sample {i}. Available fields: {list(p.keys())}. "
+                    f"Expected field: {image_field}"
+                )
+            # Ensure it's a PIL Image
+            if not isinstance(img, Image.Image):
+                if hasattr(img, 'convert'):
+                    img = img.convert('RGB')
+                else:
+                    img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
+            images.append(img)
     else:
         _maybe_convert_pdf_to_images(PDF, PAGES_DIR)
         filepaths, images = _load_images_from_dir(PAGES_DIR)
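
One caveat on the merging logic above: `datasets.concatenate_datasets` expects its inputs to share a schema, so combining sources whose image columns differ (e.g. `page_image` vs. `image`) may raise before the auto-detection ever runs. A sketch of one way to normalize column names up front (the helper and variable names are assumptions, not part of this commit):

```python
from datasets import Dataset, concatenate_datasets

def normalize_image_column(ds: Dataset, image_field: str, keep: str = "image") -> Dataset:
    """Rename the image column to a common name and drop the rest so schemas match."""
    if image_field != keep:
        ds = ds.rename_column(image_field, keep)
    return ds.remove_columns([c for c in ds.column_names if c != keep])

# Hypothetical usage: align two datasets with different image fields before concatenation.
# merged = concatenate_datasets([
#     normalize_image_column(arxiv_ds, "page_image"),
#     normalize_image_column(docvqa_ds, "image"),
# ])
```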
@@ -123,8 +376,31 @@ if need_to_build_index and retriever is None:
 # %%
 # Step 5: Embed query and search
+_t0 = time.perf_counter()
 q_vec = _embed_queries(model, processor, [QUERY])[0]
-results = retriever.search(q_vec.float().numpy(), topk=TOPK)
+query_embed_secs = time.perf_counter() - _t0
+query_np = q_vec.float().numpy()
+
+print(f"[Search] Method: {SEARCH_METHOD}")
+print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
+
+# Run the selected search method and time it
+_t0 = time.perf_counter()
+if SEARCH_METHOD == "ann":
+    results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
+    search_secs = time.perf_counter() - _t0
+    print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
+elif SEARCH_METHOD == "exact":
+    results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
+    search_secs = time.perf_counter() - _t0
+    print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
+elif SEARCH_METHOD == "exact-all":
+    results = retriever.search_exact_all(query_np, topk=TOPK)
+    search_secs = time.perf_counter() - _t0
+    print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
+else:
+    results = []
+
 if not results:
     print("No results found.")
 else:
@@ -204,6 +480,9 @@ if results and SIMILARITY_MAP:
 # Step 7: Optional answer generation
 if results and ANSWER:
     qwen = QwenVL(device=device_str)
+    _t0 = time.perf_counter()
     response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
+    gen_secs = time.perf_counter() - _t0
+    print(f"[Timing] Generation: {gen_secs:.3f}s")
     print("\nAnswer:")
     print(response)