diff --git a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
index 8353d3a..c557fef 100644
--- a/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
+++ b/apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
@@ -3,6 +3,7 @@ import json
 import os
 import re
 import sys
+import time
 from pathlib import Path
 from typing import Any, Optional, cast
 
@@ -194,7 +195,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
 
     dataloader = DataLoader(
         dataset=ListDataset[Image.Image](images),
-        batch_size=1,
+        batch_size=32,  # encode pages in batches instead of one at a time
         shuffle=False,
         collate_fn=lambda x: processor.process_images(x),
     )
@@ -678,11 +679,15 @@
             return (float(score), doc_id)
 
         scores: list[tuple[float, int]] = []
+        # Time the embedding loads plus MaxSim scoring over the candidate set
+        start_time = time.time()
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
             futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
             for fut in concurrent.futures.as_completed(futures):
                 scores.append(fut.result())
-
+        end_time = time.time()
+        print(f"Number of candidate doc IDs: {len(candidate_doc_ids)}")
+        print(f"Embedding load + scoring time: {end_time - start_time:.3f} seconds")
         scores.sort(key=lambda x: x[0], reverse=True)
         return scores[:topk] if len(scores) >= topk else scores
 
@@ -710,7 +715,6 @@
         emb_path = self._embeddings_path()
         if not emb_path.exists():
             return self.search(data, topk)
-
         all_embeddings = np.load(emb_path, mmap_mode="r")
         if all_embeddings.dtype != np.float32:
             all_embeddings = all_embeddings.astype(np.float32)
@@ -718,23 +722,29 @@
         assert self._docid_to_indices is not None
         candidate_doc_ids = list(self._docid_to_indices.keys())
 
-        def _score_one(doc_id: int) -> tuple[float, int]:
+        # Bind the memmap as a default argument so every worker thread reuses one handle
+        def _score_one(doc_id: int, _all_embeddings=all_embeddings) -> tuple[float, int]:
             token_indices = self._docid_to_indices.get(doc_id, [])
             if not token_indices:
                 return (0.0, doc_id)
-            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            doc_vecs = np.asarray(_all_embeddings[token_indices], dtype=np.float32)
             sim = np.dot(data, doc_vecs.T)
             sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
             score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
             return (float(score), doc_id)
 
         scores: list[tuple[float, int]] = []
+        # Time the embedding loads plus MaxSim scoring over all documents
+        start_time = time.time()
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
             futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
             for fut in concurrent.futures.as_completed(futures):
                 scores.append(fut.result())
-
+        end_time = time.time()
+        print(f"Number of candidate doc IDs: {len(candidate_doc_ids)}")
+        print(f"Embedding load + scoring time: {end_time - start_time:.3f} seconds")
         scores.sort(key=lambda x: x[0], reverse=True)
+        del all_embeddings  # drop the local memmap reference before returning
        return scores[:topk] if len(scores) >= topk else scores
 
     def get_image(self, doc_id: int) -> Optional[Image.Image]:
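For reviewers skimming the timing changes: both `_score_one` variants compute a ColBERT-style late-interaction (MaxSim) score, where each query token takes its best-matching page token and those maxima are summed. A minimal NumPy sketch of that scoring; the shapes are illustrative assumptions, not values from this repo:

```python
import numpy as np

# Assumed shapes for illustration: 16 query tokens, 700 page tokens, 128-dim embeddings.
rng = np.random.default_rng(0)
query = rng.standard_normal((16, 128)).astype(np.float32)  # query token embeddings
page = rng.standard_normal((700, 128)).astype(np.float32)  # one page's token embeddings

sim = query @ page.T           # (16, 700) token-to-token similarities
score = sim.max(axis=1).sum()  # best page token per query token, summed (MaxSim)
print(f"MaxSim score: {score:.3f}")
```

This is the 2-D case of `sim.max(axis=1).sum()` in `_score_one`; the `axis=2` branch handles a batched query matrix.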
diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py
index c4c01e8..4c4c061 100644
--- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py
+++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py
@@ -1,7 +1,9 @@
 ## Jupyter-style notebook script
 # %%
 # uv pip install matplotlib qwen_vl_utils
+import argparse
 import os
+import time
 from typing import Any, Optional
 
 from PIL import Image
@@ -31,8 +33,33 @@ MODEL: str = "colqwen2"  # "colpali" or "colqwen2"
 
 # Data source: set to True to use the Hugging Face dataset example (recommended)
 USE_HF_DATASET: bool = True
+# Single dataset name (used when DATASET_NAMES is None)
 DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
-DATASET_SPLIT: str = "train"
+# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
+# Can be:
+# - List of strings: ["dataset1", "dataset2"]
+# - List of tuples: [("dataset1", "config1"), ("dataset2", None)]  # None = no config needed
+# - Mixed: ["dataset1", ("dataset2", "config2")]
+#
+# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
+# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
+# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
+# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
+# - "pixparse/arxiv-papers" (if available, arXiv papers)
+# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
+# - "huggingface/document-images" (if available)
+# Note: check the dataset structure first; some need IMAGE_FIELD_NAME set explicitly
+# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
+DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = [
+    "weaviate/arXiv-AI-papers-multi-vector",
+    ("lmms-lab/DocVQA", "DocVQA"),  # Specify config name for datasets with multiple configs
+]
+# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
+# Set to None to try loading all available splits automatically
+DATASET_SPLITS: Optional[list[str]] = ["train", "test"]  # None = auto-detect all splits
+# Image field name in the dataset (auto-detect if None)
+# Common names: "page_image", "image", "images", "img"
+IMAGE_FIELD_NAME: Optional[str] = None  # None = auto-detect
 MAX_DOCS: Optional[int] = None  # limit number of pages to index; None = all
 
 # Local pages (used when USE_HF_DATASET == False)
@@ -40,7 +67,8 @@ PDF: Optional[str] = None  # e.g., "./pdfs/2004.12832v2.pdf"
 PAGES_DIR: str = "./pages"
 
 # Index + retrieval settings
-INDEX_PATH: str = "./indexes/colvision.leann"
+# Use a different index path for the larger dataset to avoid overwriting the existing index
+INDEX_PATH: str = "./indexes/colvision_large.leann"
 TOPK: int = 3
 FIRST_STAGE_K: int = 500
 REBUILD_INDEX: bool = False
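A note on the `DATASET_NAMES` convention above: entries may be bare dataset names or `(name, config)` tuples, and the loading loop further down normalizes each entry before calling `load_dataset`. A self-contained sketch of just that normalization (the `entries` list is a stand-in; nothing is fetched from the Hub here):

```python
entries = ["weaviate/arXiv-AI-papers-multi-vector", ("lmms-lab/DocVQA", "DocVQA")]

for entry in entries:
    # A bare string means "no config"; a tuple carries an explicit config name.
    name, config = entry if isinstance(entry, tuple) else (entry, None)
    if config:
        print(f"would call: load_dataset({name!r}, {config!r})")
    else:
        print(f"would call: load_dataset({name!r})")
```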
@@ -54,6 +82,26 @@
 ANSWER: bool = True
 MAX_NEW_TOKENS: int = 1024
 
+# %%
+# CLI overrides
+parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
+parser.add_argument(
+    "--search-method",
+    type=str,
+    choices=["ann", "exact", "exact-all"],
+    default="ann",
+    help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
+)
+parser.add_argument(
+    "--query",
+    type=str,
+    default=QUERY,
+    help=f"Query string to search for. Default: '{QUERY}'",
+)
+cli_args, _unknown = parser.parse_known_args()
+SEARCH_METHOD: str = cli_args.search_method
+QUERY = cli_args.query  # defaults to the QUERY set above unless --query is passed
+
 # %%
 # Step 1: Check if we can skip data loading (index already exists)
 
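`parse_known_args()` (rather than `parse_args()`) keeps this notebook-style script usable under Jupyter, where the kernel injects its own argv. A runnable illustration; the `-f` flag mimics ipykernel's connection-file argument, and the filename is made up:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--search-method", choices=["ann", "exact", "exact-all"], default="ann")
# Unrecognized arguments are collected instead of triggering a SystemExit.
args, unknown = parser.parse_known_args(["--search-method", "exact", "-f", "kernel-1234.json"])
print(args.search_method)  # exact
print(unknown)             # ['-f', 'kernel-1234.json']
```

From a shell, the same flags drive the script directly, e.g. `python multi-vector-leann-similarity-map.py --search-method exact --query "..."`.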
Skipping.") + continue + + # Concatenate splits for this dataset + if len(datasets_to_concat) > 1: + combined_dataset = concatenate_datasets(datasets_to_concat) + print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages") + else: + combined_dataset = datasets_to_concat[0] + + all_datasets_to_concat.append(combined_dataset) + + if not all_datasets_to_concat: + raise RuntimeError("No valid datasets or splits found.") + + # Concatenate all datasets + if len(all_datasets_to_concat) > 1: + dataset = concatenate_datasets(all_datasets_to_concat) + print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages") + else: + dataset = all_datasets_to_concat[0] + + # Apply MAX_DOCS limit if specified N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset)) + if N < len(dataset): + print(f"Limiting to {N} pages (from {len(dataset)} total)") + dataset = dataset.select(range(N)) + + # Auto-detect image field name if not specified + if IMAGE_FIELD_NAME is None: + # Check multiple samples to find the most common image field + # (useful when datasets are merged and may have different field names) + possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"] + field_counts = {} + + # Check first few samples to find image fields + num_samples_to_check = min(10, len(dataset)) + for sample_idx in range(num_samples_to_check): + sample = dataset[sample_idx] + for field in possible_image_fields: + if field in sample and sample[field] is not None: + value = sample[field] + if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')): + field_counts[field] = field_counts.get(field, 0) + 1 + + # Choose the most common field, or first found if tied + if field_counts: + image_field = max(field_counts.items(), key=lambda x: x[1])[0] + print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)") + else: + # Fallback: check first sample only + sample = dataset[0] + image_field = None + for field in possible_image_fields: + if field in sample: + value = sample[field] + if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')): + image_field = field + break + if image_field is None: + raise RuntimeError( + f"Could not auto-detect image field. Available fields: {list(sample.keys())}. " + f"Please specify IMAGE_FIELD_NAME manually." + ) + print(f"Auto-detected image field: '{image_field}'") + else: + image_field = IMAGE_FIELD_NAME + if image_field not in dataset[0]: + raise RuntimeError( + f"Image field '{image_field}' not found. 
@@ -123,8 +376,31 @@ if need_to_build_index and retriever is None:
 
 # %%
 # Step 5: Embed query and search
+_t0 = time.perf_counter()
 q_vec = _embed_queries(model, processor, [QUERY])[0]
-results = retriever.search(q_vec.float().numpy(), topk=TOPK)
+query_embed_secs = time.perf_counter() - _t0
+
+query_np = q_vec.float().numpy()
+
+print(f"[Search] Method: {SEARCH_METHOD}")
+print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
+
+# Run the selected search method and time it
+_t0 = time.perf_counter()
+if SEARCH_METHOD == "ann":
+    results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
+    search_secs = time.perf_counter() - _t0
+    print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
+elif SEARCH_METHOD == "exact":
+    results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
+    search_secs = time.perf_counter() - _t0
+    print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
+elif SEARCH_METHOD == "exact-all":
+    results = retriever.search_exact_all(query_np, topk=TOPK)
+    search_secs = time.perf_counter() - _t0
+    print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
+else:  # unreachable: argparse restricts --search-method to the three choices above
+    results = []
 if not results:
     print("No results found.")
 else:
@@ -204,6 +480,9 @@ if results and SIMILARITY_MAP:
 # Step 7: Optional answer generation
 if results and ANSWER:
     qwen = QwenVL(device=device_str)
+    _t0 = time.perf_counter()
     response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
+    gen_secs = time.perf_counter() - _t0
+    print(f"[Timing] Generation: {gen_secs:.3f}s")
     print("\nAnswer:")
     print(response)