Fix multimodal benchmark scripts type errors

- Fix undefined LeannRetriever -> LeannMultiVector
- Add proper type casts for HuggingFace Dataset iteration
- Cast task config values to correct types
- Add type annotations for dataset row dicts
This commit is contained in:
Andy Lee
2025-12-23 09:17:47 +00:00
parent de56ab8fa7
commit 8aa4c7e5f2
4 changed files with 68 additions and 61 deletions

View File

@@ -18,10 +18,11 @@ _repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
sys.path.insert(0, str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
sys.path.insert(0, str(_leann_hnsw_pkg))
from leann_multi_vector import LeannMultiVector
import torch
from colpali_engine.models import ColPali
@@ -93,9 +94,9 @@ for batch_doc in tqdm(dataloader):
print(ds[0].shape)
# %%
# Build HNSW index via LeannRetriever primitives and run search
# Build HNSW index via LeannMultiVector primitives and run search
index_path = "./indexes/colpali.leann"
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)):

View File

@@ -5,7 +5,7 @@ import argparse
import faulthandler
import os
import time
from typing import Any, Optional
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
@@ -223,7 +223,7 @@ if need_to_build_index:
# Use filenames as identifiers instead of full paths for cleaner metadata
filepaths = [os.path.basename(fp) for fp in filepaths]
elif USE_HF_DATASET:
from datasets import load_dataset, concatenate_datasets, DatasetDict
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
# Determine which datasets to load
if DATASET_NAMES is not None:
@@ -281,12 +281,12 @@ if need_to_build_index:
splits_to_load = DATASET_SPLITS
# Load and concatenate multiple splits for this dataset
datasets_to_concat = []
datasets_to_concat: list[Dataset] = []
for split in splits_to_load:
if split not in dataset_dict:
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
continue
split_dataset = dataset_dict[split]
split_dataset = cast(Dataset, dataset_dict[split])
print(f" Loaded split '{split}': {len(split_dataset)} pages")
datasets_to_concat.append(split_dataset)

View File

@@ -25,9 +25,9 @@ Usage:
import argparse
import json
import os
from typing import Optional
from typing import Any, Optional, cast
from datasets import load_dataset
from datasets import Dataset, load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
@@ -151,40 +151,43 @@ def load_vidore_v1_data(
"""
print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
queries = {}
queries: dict[str, str] = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
row_dict = cast(dict[str, Any], row)
query_id = f"query-{split}-{row_dict['query-id']}"
queries[query_id] = row_dict["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
# Load corpus (images) - cast to Dataset
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
corpus = {}
corpus: dict[str, Any] = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
row_dict = cast(dict[str, Any], row)
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
if "image" in row_dict:
corpus[corpus_id] = row_dict["image"]
elif "page_image" in row_dict:
corpus[corpus_id] = row_dict["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
# Load qrels (relevance judgments) - cast to Dataset
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
qrels = {}
qrels: dict[str, dict[str, int]] = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
row_dict = cast(dict[str, Any], row)
query_id = f"query-{split}-{row_dict['query-id']}"
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
qrels[query_id][corpus_id] = int(row_dict["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -234,8 +237,8 @@ def evaluate_task(
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
dataset_path = str(task_config["dataset_path"])
revision = str(task_config["revision"])
# Load data
corpus, queries, qrels = load_vidore_v1_data(
@@ -286,7 +289,7 @@ def evaluate_task(
)
# Search queries
task_prompt = task_config.get("prompt")
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,

View File

@@ -25,9 +25,9 @@ Usage:
import argparse
import json
import os
from typing import Optional
from typing import Any, Optional, cast
from datasets import load_dataset
from datasets import Dataset, load_dataset
from leann_multi_vector import (
ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable,
@@ -91,8 +91,8 @@ def load_vidore_v2_data(
"""
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
# Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
@@ -112,8 +112,8 @@ def load_vidore_v2_data(
if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values
try:
sample_ds = load_dataset(
dataset_path, "queries", split=split, revision=revision
sample_ds = cast(
Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision)
)
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"])
@@ -126,37 +126,40 @@ def load_vidore_v2_data(
)
query_ds = query_ds_filtered
queries = {}
queries: dict[str, str] = {}
for row in query_ds:
query_id = f"query-{split}-{row['query-id']}"
queries[query_id] = row["query"]
row_dict = cast(dict[str, Any], row)
query_id = f"query-{split}-{row_dict['query-id']}"
queries[query_id] = row_dict["query"]
# Load corpus (images)
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
# Load corpus (images) - cast to Dataset
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
corpus = {}
corpus: dict[str, Any] = {}
for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}"
row_dict = cast(dict[str, Any], row)
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
# Extract image from the dataset row
if "image" in row:
corpus[corpus_id] = row["image"]
elif "page_image" in row:
corpus[corpus_id] = row["page_image"]
if "image" in row_dict:
corpus[corpus_id] = row_dict["image"]
elif "page_image" in row_dict:
corpus[corpus_id] = row_dict["page_image"]
else:
raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}"
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
)
# Load qrels (relevance judgments)
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
# Load qrels (relevance judgments) - cast to Dataset
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
qrels = {}
qrels: dict[str, dict[str, int]] = {}
for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}"
corpus_id = f"corpus-{split}-{row['corpus-id']}"
row_dict = cast(dict[str, Any], row)
query_id = f"query-{split}-{row_dict['query-id']}"
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
if query_id not in qrels:
qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"])
qrels[query_id][corpus_id] = int(row_dict["score"])
print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -204,13 +207,13 @@ def evaluate_task(
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
task_config = VIDORE_V2_TASKS[task_name]
dataset_path = task_config["dataset_path"]
revision = task_config["revision"]
dataset_path = str(task_config["dataset_path"])
revision = str(task_config["revision"])
# Determine language
if language is None:
# Use first language if multiple available
languages = task_config.get("languages")
languages = cast(Optional[list[str]], task_config.get("languages"))
if languages is None:
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
language = None
@@ -269,7 +272,7 @@ def evaluate_task(
)
# Search queries
task_prompt = task_config.get("prompt")
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
results = evaluator.search_queries(
queries=queries,
corpus_ids=corpus_ids_ordered,