reproduce docvqa results and add debug file
This commit is contained in:
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new('RGB', (800, 600), color='white')
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in ['.png', '.jpg', '.jpeg']:
|
||||
for img_file in img_dir.glob(f'*{ext}'):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != 'RGB':
|
||||
test_image = test_image.convert('RGB')
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, 'device'):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, 'dtype'):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print(f"✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print(f"✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -236,14 +236,12 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||
|
||||
with torch.no_grad():
|
||||
for i in tqdm(range(0, len(queries), batch_size), desc="Embedding queries"):
|
||||
batch_queries = queries[i:i + batch_size]
|
||||
batch_queries = queries[i : i + batch_size]
|
||||
|
||||
# Match MTEB: manually add query_prefix + text + query_augmentation_token * 10
|
||||
# Then process_queries will add them again (resulting in 20 augmentation tokens total)
|
||||
batch = [
|
||||
processor.query_prefix
|
||||
+ t
|
||||
+ processor.query_augmentation_token * 10
|
||||
processor.query_prefix + t + processor.query_augmentation_token * 10
|
||||
for t in batch_queries
|
||||
]
|
||||
inputs = processor.process_queries(batch)
|
||||
@@ -331,7 +329,11 @@ def _build_fast_plaid_index(
|
||||
if i % 1000 == 0:
|
||||
print(f" Converting embedding {i}/{len(doc_vecs)}...")
|
||||
if not isinstance(vec, torch.Tensor):
|
||||
vec = torch.tensor(vec) if isinstance(vec, np.ndarray) else torch.from_numpy(np.array(vec))
|
||||
vec = (
|
||||
torch.tensor(vec)
|
||||
if isinstance(vec, np.ndarray)
|
||||
else torch.from_numpy(np.array(vec))
|
||||
)
|
||||
# Ensure float32 for Fast-Plaid
|
||||
if vec.dtype != torch.float32:
|
||||
vec = vec.float()
|
||||
@@ -346,19 +348,22 @@ def _build_fast_plaid_index(
|
||||
print(f" Preparing metadata for {len(filepaths)} documents...")
|
||||
metadata_list = []
|
||||
for i, filepath in enumerate(filepaths):
|
||||
metadata_list.append({
|
||||
"filepath": filepath,
|
||||
"index": i,
|
||||
})
|
||||
metadata_list.append(
|
||||
{
|
||||
"filepath": filepath,
|
||||
"index": i,
|
||||
}
|
||||
)
|
||||
|
||||
# Create Fast-Plaid index
|
||||
print(f" Creating FastPlaid object with index path: {index_path}")
|
||||
try:
|
||||
fast_plaid_index = fast_plaid_search.FastPlaid(index=index_path)
|
||||
print(f" FastPlaid object created successfully")
|
||||
print(" FastPlaid object created successfully")
|
||||
except Exception as e:
|
||||
print(f" Error creating FastPlaid object: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
@@ -368,10 +373,11 @@ def _build_fast_plaid_index(
|
||||
documents_embeddings=documents_embeddings,
|
||||
metadata=metadata_list,
|
||||
)
|
||||
print(f" Fast-Plaid index created successfully")
|
||||
print(" Fast-Plaid index created successfully")
|
||||
except Exception as e:
|
||||
print(f" Error creating Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
@@ -475,7 +481,11 @@ def _search_fast_plaid(
|
||||
|
||||
# Ensure query is a torch tensor
|
||||
if not isinstance(query_vec, torch.Tensor):
|
||||
q_vec_tensor = torch.tensor(query_vec) if isinstance(query_vec, np.ndarray) else torch.from_numpy(np.array(query_vec))
|
||||
q_vec_tensor = (
|
||||
torch.tensor(query_vec)
|
||||
if isinstance(query_vec, np.ndarray)
|
||||
else torch.from_numpy(np.array(query_vec))
|
||||
)
|
||||
else:
|
||||
q_vec_tensor = query_vec
|
||||
|
||||
@@ -508,16 +518,35 @@ def _get_fast_plaid_image(index_path: str, doc_id: int) -> Optional[Image.Image]
|
||||
|
||||
Args:
|
||||
index_path: Path to the Fast-Plaid index
|
||||
doc_id: Document ID
|
||||
doc_id: Document ID returned by Fast-Plaid search
|
||||
|
||||
Returns:
|
||||
PIL Image if found, None otherwise
|
||||
|
||||
Note: Uses metadata['index'] to get the actual file index, as Fast-Plaid
|
||||
doc_id may differ from the file naming index.
|
||||
"""
|
||||
# First get metadata to find the actual index used for file naming
|
||||
metadata = _get_fast_plaid_metadata(index_path, doc_id)
|
||||
if metadata is None:
|
||||
# Fallback: try using doc_id directly
|
||||
file_index = doc_id
|
||||
else:
|
||||
# Use the 'index' field from metadata, which matches the file naming
|
||||
file_index = metadata.get("index", doc_id)
|
||||
|
||||
images_dir = Path(index_path) / "images"
|
||||
image_path = images_dir / f"doc_{doc_id}.png"
|
||||
image_path = images_dir / f"doc_{file_index}.png"
|
||||
|
||||
if image_path.exists():
|
||||
return Image.open(image_path)
|
||||
|
||||
# If not found with index, try doc_id as fallback
|
||||
if file_index != doc_id:
|
||||
fallback_path = images_dir / f"doc_{doc_id}.png"
|
||||
if fallback_path.exists():
|
||||
return Image.open(fallback_path)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -534,6 +563,7 @@ def _get_fast_plaid_metadata(index_path: str, doc_id: int) -> Optional[dict]:
|
||||
"""
|
||||
try:
|
||||
from fast_plaid import filtering
|
||||
|
||||
metadata_list = filtering.get(index=index_path, subset=[doc_id])
|
||||
if metadata_list and len(metadata_list) > 0:
|
||||
return metadata_list[0]
|
||||
@@ -1060,7 +1090,7 @@ class ViDoReBenchmarkEvaluator:
|
||||
use_fast_plaid: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: list[int] = None,
|
||||
k_values: Optional[list[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the evaluator.
|
||||
@@ -1087,7 +1117,9 @@ class ViDoReBenchmarkEvaluator:
|
||||
"""Lazy load the model."""
|
||||
if self._model is None:
|
||||
print(f"\nLoading model: {self.model_name}")
|
||||
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(self.model_name)
|
||||
self._model_name_actual, self._model, self._processor, _, _, _ = _load_colvision(
|
||||
self.model_name
|
||||
)
|
||||
print(f"Model loaded: {self._model_name_actual}")
|
||||
|
||||
def build_index_from_corpus(
|
||||
@@ -1141,7 +1173,7 @@ class ViDoReBenchmarkEvaluator:
|
||||
doc_vecs = _embed_images(self._model, self._processor, images)
|
||||
|
||||
retriever = _build_index(index_path, doc_vecs, corpus_ids, images)
|
||||
print(f"LEANN index built")
|
||||
print("LEANN index built")
|
||||
return retriever, corpus_ids
|
||||
|
||||
def search_queries(
|
||||
@@ -1194,8 +1226,11 @@ class ViDoReBenchmarkEvaluator:
|
||||
else:
|
||||
# LEANN search
|
||||
import torch
|
||||
query_np = query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
|
||||
search_results = index_or_retriever.search_exact_all(query_np, topk=self.top_k)
|
||||
|
||||
query_np = (
|
||||
query_vec.float().numpy() if isinstance(query_vec, torch.Tensor) else query_vec
|
||||
)
|
||||
search_results = index_or_retriever.search_exact(query_np, topk=self.top_k)
|
||||
query_results = {}
|
||||
for score, doc_id in search_results:
|
||||
if doc_id < len(corpus_ids):
|
||||
@@ -1210,7 +1245,7 @@ class ViDoReBenchmarkEvaluator:
|
||||
def evaluate_results(
|
||||
results: dict[str, dict[str, float]],
|
||||
qrels: dict[str, dict[str, int]],
|
||||
k_values: list[int] = None,
|
||||
k_values: Optional[list[int]] = None,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Evaluate retrieval results using NDCG and other metrics.
|
||||
@@ -1224,13 +1259,14 @@ class ViDoReBenchmarkEvaluator:
|
||||
Dictionary of metric scores
|
||||
"""
|
||||
try:
|
||||
import pytrec_eval
|
||||
from mteb._evaluators.retrieval_metrics import (
|
||||
calculate_retrieval_scores,
|
||||
make_score_dict,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("pytrec_eval is required for evaluation. Install with: pip install pytrec-eval")
|
||||
raise ImportError(
|
||||
"pytrec_eval is required for evaluation. Install with: pip install pytrec-eval"
|
||||
)
|
||||
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
@@ -1255,12 +1291,16 @@ class ViDoReBenchmarkEvaluator:
|
||||
# pytrec_eval only evaluates queries in qrels, so we need to ensure
|
||||
# results contains all queries in qrels, and filter out queries not in qrels
|
||||
results_filtered = {qid: res for qid, res in results.items() if qid in qrels}
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered}
|
||||
qrels_filtered = {
|
||||
qid: rel_docs for qid, rel_docs in qrels.items() if qid in results_filtered
|
||||
}
|
||||
|
||||
print(f"After filtering: {len(results_filtered)} results, {len(qrels_filtered)} qrels")
|
||||
|
||||
if len(results_filtered) != len(qrels_filtered):
|
||||
print(f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries")
|
||||
print(
|
||||
f"Warning: Mismatch between results ({len(results_filtered)}) and qrels ({len(qrels_filtered)}) queries"
|
||||
)
|
||||
missing_in_results = set(qrels.keys()) - set(results.keys())
|
||||
if missing_in_results:
|
||||
print(f"Queries in qrels but not in results: {len(missing_in_results)} queries")
|
||||
@@ -1269,7 +1309,7 @@ class ViDoReBenchmarkEvaluator:
|
||||
# Convert qrels to pytrec_eval format
|
||||
qrels_pytrec = {}
|
||||
for qid, rel_docs in qrels_filtered.items():
|
||||
qrels_pytrec[qid] = {did: score for did, score in rel_docs.items()}
|
||||
qrels_pytrec[qid] = dict(rel_docs.items())
|
||||
|
||||
# Evaluate
|
||||
eval_result = calculate_retrieval_scores(
|
||||
|
||||
@@ -83,7 +83,7 @@ INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False
|
||||
REBUILD_INDEX: bool = True
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
@@ -122,11 +122,18 @@ parser.add_argument(
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
|
||||
# %%
|
||||
|
||||
@@ -548,7 +555,10 @@ if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
print("\n[DEBUG] Retrieval details:")
|
||||
top_images: list[Image.Image] = []
|
||||
image_hashes = {} # Track image hashes to detect duplicates
|
||||
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
@@ -572,8 +582,26 @@ else:
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||
# Calculate image hash to detect duplicates
|
||||
import hashlib
|
||||
import io
|
||||
# Convert image to bytes for hashing
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
image_bytes = img_bytes.getvalue()
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||
|
||||
# Check if this image was already seen
|
||||
duplicate_info = ""
|
||||
if image_hash in image_hashes:
|
||||
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||
else:
|
||||
image_hashes[image_hash] = rank
|
||||
|
||||
# Print detailed information
|
||||
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||
if metadata:
|
||||
print(f" Metadata: {metadata}")
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
@@ -28,11 +28,9 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
from leann_multi_vector import (
|
||||
_ensure_repo_paths_importable,
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
@@ -128,7 +126,9 @@ def load_vidore_v1_data(
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
@@ -141,7 +141,9 @@ def load_vidore_v1_data(
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
@@ -149,9 +151,13 @@ def load_vidore_v1_data(
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
@@ -165,15 +171,15 @@ def evaluate_task(
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: list[int] = [1, 3, 5, 10, 20, 100, 1000],
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
@@ -190,6 +196,10 @@ def evaluate_task(
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
@@ -239,9 +249,9 @@ def evaluate_task(
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
@@ -368,14 +378,15 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
@@ -386,4 +397,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -28,11 +28,9 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
from leann_multi_vector import (
|
||||
_ensure_repo_paths_importable,
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
@@ -105,21 +103,27 @@ def load_vidore_v2_data(
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}').")
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||
sample_ds = load_dataset(
|
||||
dataset_path, "queries", split=split, revision=revision
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(f"Found {len(query_ds_filtered)} queries using original language value '{language}'")
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries = {}
|
||||
@@ -139,7 +143,9 @@ def load_vidore_v2_data(
|
||||
elif "page_image" in row:
|
||||
corpus[corpus_id] = row["page_image"]
|
||||
else:
|
||||
raise ValueError(f"No image field found in corpus. Available fields: {list(row.keys())}")
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments)
|
||||
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||
@@ -152,7 +158,9 @@ def load_vidore_v2_data(
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row["score"])
|
||||
|
||||
print(f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings")
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
@@ -160,9 +168,13 @@ def load_vidore_v2_data(
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings")
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
@@ -177,15 +189,15 @@ def evaluate_task(
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: list[int] = [1, 3, 5, 10, 100],
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
@@ -207,6 +219,10 @@ def evaluate_task(
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
@@ -217,7 +233,9 @@ def evaluate_task(
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation.")
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
@@ -264,9 +282,9 @@ def evaluate_task(
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
@@ -400,14 +418,15 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'='*80}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*80}")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
@@ -418,4 +437,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user