From 8aa4c7e5f26c310d86178002d23ebe33a0aa86b0 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Tue, 23 Dec 2025 09:17:47 +0000
Subject: [PATCH] Fix type errors in multimodal benchmark scripts

- Fix undefined LeannRetriever -> LeannMultiVector
- Add proper type casts for HuggingFace Dataset iteration
- Cast task config values to correct types
- Add type annotations for dataset row dicts
---
 .../multi-vector-leann-paper-example.py  |  9 +--
 .../multi-vector-leann-similarity-map.py |  8 +--
 .../vidore_v1_benchmark.py               | 53 +++++++++--------
 .../vidore_v2_benchmark.py               | 59 ++++++++++---------
 4 files changed, 68 insertions(+), 61 deletions(-)

diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py
index 22102d3..16107ca 100644
--- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py
+++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py
@@ -18,10 +18,11 @@
 _repo_root = Path(__file__).resolve().parents[3]
 _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
 _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
 if str(_leann_core_src) not in sys.path:
-    sys.path.append(str(_leann_core_src))
+    sys.path.insert(0, str(_leann_core_src))
 if str(_leann_hnsw_pkg) not in sys.path:
-    sys.path.append(str(_leann_hnsw_pkg))
+    sys.path.insert(0, str(_leann_hnsw_pkg))
+from leann_multi_vector import LeannMultiVector
 
 import torch
 from colpali_engine.models import ColPali
@@ -93,9 +94,9 @@ for batch_doc in tqdm(dataloader):
 print(ds[0].shape)
 
 # %%
-# Build HNSW index via LeannRetriever primitives and run search
+# Build HNSW index via LeannMultiVector primitives and run search
 index_path = "./indexes/colpali.leann"
-retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
+retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
 retriever.create_collection()
 filepaths = [os.path.join("./pages", name) for name in page_filenames]
 for i in range(len(filepaths)):
diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py
index fcde09f..f1be682 100644
--- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py
+++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py
@@ -5,7 +5,7 @@ import argparse
 import faulthandler
 import os
 import time
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 import numpy as np
 from PIL import Image
@@ -223,7 +223,7 @@ if need_to_build_index:
         # Use filenames as identifiers instead of full paths for cleaner metadata
         filepaths = [os.path.basename(fp) for fp in filepaths]
     elif USE_HF_DATASET:
-        from datasets import load_dataset, concatenate_datasets, DatasetDict
+        from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
 
         # Determine which datasets to load
        if DATASET_NAMES is not None:
@@ -281,12 +281,12 @@
             splits_to_load = DATASET_SPLITS
 
         # Load and concatenate multiple splits for this dataset
-        datasets_to_concat = []
+        datasets_to_concat: list[Dataset] = []
         for split in splits_to_load:
             if split not in dataset_dict:
                 print(f"  Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
                 continue
-            split_dataset = dataset_dict[split]
+            split_dataset = cast(Dataset, dataset_dict[split])
             print(f"  Loaded split '{split}': {len(split_dataset)} pages")
             datasets_to_concat.append(split_dataset)
 
diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py
index 79472df..3b2d7df 100644
--- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py
+++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py
@@ -25,9 +25,9 @@ Usage:
 import argparse
 import json
 import os
-from typing import Optional
+from typing import Any, Optional, cast
 
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from leann_multi_vector import (
     ViDoReBenchmarkEvaluator,
     _ensure_repo_paths_importable,
@@ -151,40 +151,43 @@ def load_vidore_v1_data(
     """
     print(f"Loading dataset: {dataset_path} (split={split})")
 
-    # Load queries
-    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+    # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
+    query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
 
-    queries = {}
+    queries: dict[str, str] = {}
     for row in query_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        queries[query_id] = row["query"]
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        queries[query_id] = row_dict["query"]
 
-    # Load corpus (images)
-    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
+    # Load corpus (images) - cast to Dataset
+    corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
 
-    corpus = {}
+    corpus: dict[str, Any] = {}
     for row in corpus_ds:
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         # Extract image from the dataset row
-        if "image" in row:
-            corpus[corpus_id] = row["image"]
-        elif "page_image" in row:
-            corpus[corpus_id] = row["page_image"]
+        if "image" in row_dict:
+            corpus[corpus_id] = row_dict["image"]
+        elif "page_image" in row_dict:
+            corpus[corpus_id] = row_dict["page_image"]
         else:
             raise ValueError(
-                f"No image field found in corpus. Available fields: {list(row.keys())}"
+                f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
             )
 
-    # Load qrels (relevance judgments)
-    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
+    # Load qrels (relevance judgments) - cast to Dataset
+    qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
 
-    qrels = {}
+    qrels: dict[str, dict[str, int]] = {}
     for row in qrels_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         if query_id not in qrels:
             qrels[query_id] = {}
-        qrels[query_id][corpus_id] = int(row["score"])
+        qrels[query_id][corpus_id] = int(row_dict["score"])
 
     print(
         f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -234,8 +237,8 @@ def evaluate_task(
         raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
 
     task_config = VIDORE_V1_TASKS[task_name]
-    dataset_path = task_config["dataset_path"]
-    revision = task_config["revision"]
+    dataset_path = str(task_config["dataset_path"])
+    revision = str(task_config["revision"])
 
     # Load data
     corpus, queries, qrels = load_vidore_v1_data(
@@ -286,7 +289,7 @@ def evaluate_task(
     )
 
     # Search queries
-    task_prompt = task_config.get("prompt")
+    task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
     results = evaluator.search_queries(
         queries=queries,
         corpus_ids=corpus_ids_ordered,
diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py
index 8a34e69..d6130d8 100644
--- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py
+++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py
@@ -25,9 +25,9 @@ Usage:
 import argparse
 import json
 import os
-from typing import Optional
+from typing import Any, Optional, cast
 
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from leann_multi_vector import (
     ViDoReBenchmarkEvaluator,
     _ensure_repo_paths_importable,
@@ -91,8 +91,8 @@ def load_vidore_v2_data(
     """
     print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
 
-    # Load queries
-    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+    # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
+    query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
 
     # Check if dataset has language field before filtering
     has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
@@ -112,8 +112,8 @@
     if len(query_ds_filtered) == 0:
         # Try to get a sample to see actual language values
         try:
-            sample_ds = load_dataset(
-                dataset_path, "queries", split=split, revision=revision
+            sample_ds = cast(
+                Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision)
             )
             if len(sample_ds) > 0 and "language" in sample_ds.column_names:
                 sample_langs = set(sample_ds["language"])
@@ -126,37 +126,40 @@ def load_vidore_v2_data(
         )
     query_ds = query_ds_filtered
 
-    queries = {}
+    queries: dict[str, str] = {}
     for row in query_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        queries[query_id] = row["query"]
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        queries[query_id] = row_dict["query"]
 
-    # Load corpus (images)
-    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
+    # Load corpus (images) - cast to Dataset
+    corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
 
-    corpus = {}
+    corpus: dict[str, Any] = {}
     for row in corpus_ds:
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         # Extract image from the dataset row
-        if "image" in row:
-            corpus[corpus_id] = row["image"]
-        elif "page_image" in row:
-            corpus[corpus_id] = row["page_image"]
+        if "image" in row_dict:
+            corpus[corpus_id] = row_dict["image"]
+        elif "page_image" in row_dict:
+            corpus[corpus_id] = row_dict["page_image"]
         else:
             raise ValueError(
-                f"No image field found in corpus. Available fields: {list(row.keys())}"
+                f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
             )
 
-    # Load qrels (relevance judgments)
-    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
+    # Load qrels (relevance judgments) - cast to Dataset
+    qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
 
-    qrels = {}
+    qrels: dict[str, dict[str, int]] = {}
     for row in qrels_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         if query_id not in qrels:
             qrels[query_id] = {}
-        qrels[query_id][corpus_id] = int(row["score"])
+        qrels[query_id][corpus_id] = int(row_dict["score"])
 
     print(
         f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -204,13 +207,13 @@ def evaluate_task(
         raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
 
     task_config = VIDORE_V2_TASKS[task_name]
-    dataset_path = task_config["dataset_path"]
-    revision = task_config["revision"]
+    dataset_path = str(task_config["dataset_path"])
+    revision = str(task_config["revision"])
 
     # Determine language
     if language is None:
         # Use first language if multiple available
-        languages = task_config.get("languages")
+        languages = cast(Optional[list[str]], task_config.get("languages"))
         if languages is None:
             # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
             language = None
@@ -269,7 +272,7 @@ def evaluate_task(
     )
 
     # Search queries
-    task_prompt = task_config.get("prompt")
+    task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
     results = evaluator.search_queries(
         queries=queries,
         corpus_ids=corpus_ids_ordered,
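
-- 
The casts in this patch all follow one pattern, so here is a minimal sketch of
why they are needed; the dataset id "user/dataset" and the field names "id" and
"text" below are placeholders, not names from the patched scripts:

    from typing import Any, Optional, cast

    from datasets import Dataset, load_dataset

    # load_dataset() is annotated to return a union
    # (Dataset | DatasetDict | IterableDataset | IterableDatasetDict).
    # Passing split= yields a Dataset at runtime, but type checkers cannot
    # narrow the union on their own, hence the explicit cast.
    ds = cast(Dataset, load_dataset("user/dataset", split="test"))  # placeholder id

    texts: dict[str, str] = {}
    for row in ds:
        # Iterating a Dataset yields loosely typed mappings; casting once per
        # row lets field accesses type-check.
        row_dict = cast(dict[str, Any], row)
        texts[str(row_dict["id"])] = row_dict["text"]  # placeholder fields

    # Heterogeneous config dicts type each value as Any/object; convert or
    # cast each value to the concrete type the call site expects.
    config: dict[str, Any] = {"dataset_path": "user/dataset", "prompt": None}
    dataset_path = str(config["dataset_path"])
    prompt = cast(Optional[dict[str, str]], config.get("prompt"))

Runtime behavior is unchanged throughout; the casts only inform the type checker.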