Fix multimodal benchmark scripts type errors

- Fix undefined LeannRetriever -> LeannMultiVector
- Add proper type casts for HuggingFace Dataset iteration
- Cast task config values to correct types
- Add type annotations for dataset row dicts
This commit is contained in:
Andy Lee
2025-12-23 09:17:47 +00:00
parent de56ab8fa7
commit 8aa4c7e5f2
4 changed files with 68 additions and 61 deletions

View File

@@ -18,10 +18,11 @@ _repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src" _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw" _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path: if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src)) sys.path.insert(0, str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path: if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg)) sys.path.insert(0, str(_leann_hnsw_pkg))
from leann_multi_vector import LeannMultiVector
import torch import torch
from colpali_engine.models import ColPali from colpali_engine.models import ColPali
@@ -93,9 +94,9 @@ for batch_doc in tqdm(dataloader):
print(ds[0].shape) print(ds[0].shape)
# %% # %%
# Build HNSW index via LeannRetriever primitives and run search # Build HNSW index via LeannMultiVector primitives and run search
index_path = "./indexes/colpali.leann" index_path = "./indexes/colpali.leann"
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1])) retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection() retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames] filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)): for i in range(len(filepaths)):

View File

@@ -5,7 +5,7 @@ import argparse
import faulthandler import faulthandler
import os import os
import time import time
from typing import Any, Optional from typing import Any, Optional, cast
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -223,7 +223,7 @@ if need_to_build_index:
# Use filenames as identifiers instead of full paths for cleaner metadata # Use filenames as identifiers instead of full paths for cleaner metadata
filepaths = [os.path.basename(fp) for fp in filepaths] filepaths = [os.path.basename(fp) for fp in filepaths]
elif USE_HF_DATASET: elif USE_HF_DATASET:
from datasets import load_dataset, concatenate_datasets, DatasetDict from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
# Determine which datasets to load # Determine which datasets to load
if DATASET_NAMES is not None: if DATASET_NAMES is not None:
@@ -281,12 +281,12 @@ if need_to_build_index:
splits_to_load = DATASET_SPLITS splits_to_load = DATASET_SPLITS
# Load and concatenate multiple splits for this dataset # Load and concatenate multiple splits for this dataset
datasets_to_concat = [] datasets_to_concat: list[Dataset] = []
for split in splits_to_load: for split in splits_to_load:
if split not in dataset_dict: if split not in dataset_dict:
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}") print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
continue continue
split_dataset = dataset_dict[split] split_dataset = cast(Dataset, dataset_dict[split])
print(f" Loaded split '{split}': {len(split_dataset)} pages") print(f" Loaded split '{split}': {len(split_dataset)} pages")
datasets_to_concat.append(split_dataset) datasets_to_concat.append(split_dataset)

View File

@@ -25,9 +25,9 @@ Usage:
import argparse import argparse
import json import json
import os import os
from typing import Optional from typing import Any, Optional, cast
from datasets import load_dataset from datasets import Dataset, load_dataset
from leann_multi_vector import ( from leann_multi_vector import (
ViDoReBenchmarkEvaluator, ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable, _ensure_repo_paths_importable,
@@ -151,40 +151,43 @@ def load_vidore_v1_data(
""" """
print(f"Loading dataset: {dataset_path} (split={split})") print(f"Loading dataset: {dataset_path} (split={split})")
# Load queries # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
queries = {} queries: dict[str, str] = {}
for row in query_ds: for row in query_ds:
query_id = f"query-{split}-{row['query-id']}" row_dict = cast(dict[str, Any], row)
queries[query_id] = row["query"] query_id = f"query-{split}-{row_dict['query-id']}"
queries[query_id] = row_dict["query"]
# Load corpus (images) # Load corpus (images) - cast to Dataset
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
corpus = {} corpus: dict[str, Any] = {}
for row in corpus_ds: for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}" row_dict = cast(dict[str, Any], row)
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
# Extract image from the dataset row # Extract image from the dataset row
if "image" in row: if "image" in row_dict:
corpus[corpus_id] = row["image"] corpus[corpus_id] = row_dict["image"]
elif "page_image" in row: elif "page_image" in row_dict:
corpus[corpus_id] = row["page_image"] corpus[corpus_id] = row_dict["page_image"]
else: else:
raise ValueError( raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}" f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
) )
# Load qrels (relevance judgments) # Load qrels (relevance judgments) - cast to Dataset
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
qrels = {} qrels: dict[str, dict[str, int]] = {}
for row in qrels_ds: for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}" row_dict = cast(dict[str, Any], row)
corpus_id = f"corpus-{split}-{row['corpus-id']}" query_id = f"query-{split}-{row_dict['query-id']}"
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
if query_id not in qrels: if query_id not in qrels:
qrels[query_id] = {} qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"]) qrels[query_id][corpus_id] = int(row_dict["score"])
print( print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings" f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -234,8 +237,8 @@ def evaluate_task(
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}") raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
task_config = VIDORE_V1_TASKS[task_name] task_config = VIDORE_V1_TASKS[task_name]
dataset_path = task_config["dataset_path"] dataset_path = str(task_config["dataset_path"])
revision = task_config["revision"] revision = str(task_config["revision"])
# Load data # Load data
corpus, queries, qrels = load_vidore_v1_data( corpus, queries, qrels = load_vidore_v1_data(
@@ -286,7 +289,7 @@ def evaluate_task(
) )
# Search queries # Search queries
task_prompt = task_config.get("prompt") task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
results = evaluator.search_queries( results = evaluator.search_queries(
queries=queries, queries=queries,
corpus_ids=corpus_ids_ordered, corpus_ids=corpus_ids_ordered,

View File

@@ -25,9 +25,9 @@ Usage:
import argparse import argparse
import json import json
import os import os
from typing import Optional from typing import Any, Optional, cast
from datasets import load_dataset from datasets import Dataset, load_dataset
from leann_multi_vector import ( from leann_multi_vector import (
ViDoReBenchmarkEvaluator, ViDoReBenchmarkEvaluator,
_ensure_repo_paths_importable, _ensure_repo_paths_importable,
@@ -91,8 +91,8 @@ def load_vidore_v2_data(
""" """
print(f"Loading dataset: {dataset_path} (split={split}, language={language})") print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
# Load queries # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
# Check if dataset has language field before filtering # Check if dataset has language field before filtering
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
@@ -112,8 +112,8 @@ def load_vidore_v2_data(
if len(query_ds_filtered) == 0: if len(query_ds_filtered) == 0:
# Try to get a sample to see actual language values # Try to get a sample to see actual language values
try: try:
sample_ds = load_dataset( sample_ds = cast(
dataset_path, "queries", split=split, revision=revision Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision)
) )
if len(sample_ds) > 0 and "language" in sample_ds.column_names: if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"]) sample_langs = set(sample_ds["language"])
@@ -126,37 +126,40 @@ def load_vidore_v2_data(
) )
query_ds = query_ds_filtered query_ds = query_ds_filtered
queries = {} queries: dict[str, str] = {}
for row in query_ds: for row in query_ds:
query_id = f"query-{split}-{row['query-id']}" row_dict = cast(dict[str, Any], row)
queries[query_id] = row["query"] query_id = f"query-{split}-{row_dict['query-id']}"
queries[query_id] = row_dict["query"]
# Load corpus (images) # Load corpus (images) - cast to Dataset
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
corpus = {} corpus: dict[str, Any] = {}
for row in corpus_ds: for row in corpus_ds:
corpus_id = f"corpus-{split}-{row['corpus-id']}" row_dict = cast(dict[str, Any], row)
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
# Extract image from the dataset row # Extract image from the dataset row
if "image" in row: if "image" in row_dict:
corpus[corpus_id] = row["image"] corpus[corpus_id] = row_dict["image"]
elif "page_image" in row: elif "page_image" in row_dict:
corpus[corpus_id] = row["page_image"] corpus[corpus_id] = row_dict["page_image"]
else: else:
raise ValueError( raise ValueError(
f"No image field found in corpus. Available fields: {list(row.keys())}" f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
) )
# Load qrels (relevance judgments) # Load qrels (relevance judgments) - cast to Dataset
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
qrels = {} qrels: dict[str, dict[str, int]] = {}
for row in qrels_ds: for row in qrels_ds:
query_id = f"query-{split}-{row['query-id']}" row_dict = cast(dict[str, Any], row)
corpus_id = f"corpus-{split}-{row['corpus-id']}" query_id = f"query-{split}-{row_dict['query-id']}"
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
if query_id not in qrels: if query_id not in qrels:
qrels[query_id] = {} qrels[query_id] = {}
qrels[query_id][corpus_id] = int(row["score"]) qrels[query_id][corpus_id] = int(row_dict["score"])
print( print(
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings" f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -204,13 +207,13 @@ def evaluate_task(
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}") raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
task_config = VIDORE_V2_TASKS[task_name] task_config = VIDORE_V2_TASKS[task_name]
dataset_path = task_config["dataset_path"] dataset_path = str(task_config["dataset_path"])
revision = task_config["revision"] revision = str(task_config["revision"])
# Determine language # Determine language
if language is None: if language is None:
# Use first language if multiple available # Use first language if multiple available
languages = task_config.get("languages") languages = cast(Optional[list[str]], task_config.get("languages"))
if languages is None: if languages is None:
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval) # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
language = None language = None
@@ -269,7 +272,7 @@ def evaluate_task(
) )
# Search queries # Search queries
task_prompt = task_config.get("prompt") task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
results = evaluator.search_queries( results = evaluator.search_queries(
queries=queries, queries=queries,
corpus_ids=corpus_ids_ordered, corpus_ids=corpus_ids_ordered,