* Add timing instrumentation and multi-dataset support for multi-vector retrieval - Add timing measurements for search operations (load and core time) - Increase embedding batch size from 1 to 32 for better performance - Add explicit memory cleanup with del all_embeddings - Support loading and merging multiple datasets with different splits - Add CLI arguments for search method selection (ann/exact/exact-all) - Auto-detect image field names across different dataset structures - Print candidate doc counts for performance monitoring 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * update vidore * reproduce docvqa results * reproduce docvqa results and add debug file --------- Co-authored-by: Claude <noreply@anthropic.com>
133 lines
4.2 KiB
Python
Executable File
133 lines
4.2 KiB
Python
Executable File
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""

import os
import sys
from pathlib import Path

# Make the script's own directory importable so leann_multi_vector resolves
# regardless of the current working directory.
sys.path.insert(0, str(Path(__file__).parent))

from PIL import Image
import torch

from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable

# Let the helper register any additional repo paths it needs.
_ensure_repo_paths_importable(__file__)

# Avoid the HuggingFace tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
def create_test_image():
    """Create a simple test image.

    Returns a blank 800x600 white RGB image, suitable as a minimal
    forward-pass input when no real image is available on disk.
    """
    width, height = 800, 600
    return Image.new('RGB', (width, height), color='white')
|
|
|
|
|
|
def load_test_image_from_file():
    """Try to load an image from the indexes directory if available.

    Scans a few known image folders under ./indexes and returns the first
    PNG/JPEG file found (opened with PIL), or None when nothing matches.
    """
    indexes_dir = Path(__file__).parent / "indexes"

    # Common locations where previously-built indexes store their images.
    candidate_dirs = (
        indexes_dir / "vidore_fastplaid" / "images",
        indexes_dir / "colvision_large.leann.images",
        indexes_dir / "colvision.leann.images",
    )

    for directory in candidate_dirs:
        if not directory.exists():
            continue
        # Return the first file matching any of the supported extensions.
        for ext in ('.png', '.jpg', '.jpeg'):
            for candidate in directory.glob(f'*{ext}'):
                print(f"Loading test image from: {candidate}")
                return Image.open(candidate)

    return None
|
|
|
|
|
|
def main():
    """Run a single-image ColQwen2 forward-pass smoke test.

    Three steps, each reported to stdout:
      1. Load a test image from ./indexes (or synthesize a white 800x600 one).
      2. Load the ColQwen2 model via _load_colvision.
      3. Embed the image via _embed_images and print embedding stats,
         warning on NaN/Inf values.

    Errors in steps 2-3 are caught broadly (this is a debug script),
    printed with a traceback, and abort the run via early return.
    """
    print("=" * 60)
    print("Testing ColQwen2 Forward Pass")
    print("=" * 60)

    # Step 1: Load or create test image
    print("\n[Step 1] Loading test image...")
    test_image = load_test_image_from_file()
    if test_image is None:
        print("No existing image found, creating a simple test image...")
        test_image = create_test_image()
    else:
        print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")

    # The model expects RGB input; convert palette/grayscale images.
    if test_image.mode != 'RGB':
        test_image = test_image.convert('RGB')
        print(f"✓ Converted to RGB: {test_image.size}")

    # Step 2: Load model
    print("\n[Step 2] Loading ColQwen2 model...")
    try:
        model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
        print(f"✓ Model loaded: {model_name}")
        print(f"✓ Device: {device_str}, dtype: {dtype}")

        # Print model info (attributes may be absent on some wrappers).
        if hasattr(model, 'device'):
            print(f"✓ Model device: {model.device}")
        if hasattr(model, 'dtype'):
            print(f"✓ Model dtype: {model.dtype}")

    except Exception as e:
        print(f"✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 3: Test forward pass
    print("\n[Step 3] Running forward pass...")
    try:
        # _embed_images handles batching and the actual forward pass.
        images = [test_image]
        print(f"Processing {len(images)} image(s)...")

        doc_vecs = _embed_images(model, processor, images)

        # Fixed F541: removed extraneous f-prefix on placeholder-free strings.
        print("✓ Forward pass completed!")
        print(f"✓ Number of embeddings: {len(doc_vecs)}")

        if len(doc_vecs) > 0:
            emb = doc_vecs[0]
            print(f"✓ Embedding shape: {emb.shape}")
            print(f"✓ Embedding dtype: {emb.dtype}")
            print("✓ Embedding stats:")
            print(f"  - Min: {emb.min().item():.4f}")
            print(f"  - Max: {emb.max().item():.4f}")
            print(f"  - Mean: {emb.mean().item():.4f}")
            print(f"  - Std: {emb.std().item():.4f}")

            # NaN/Inf in the embedding indicates a dtype/precision problem.
            if torch.isnan(emb).any():
                print("⚠ Warning: Embedding contains NaN values!")
            if torch.isinf(emb).any():
                print("⚠ Warning: Embedding contains Inf values!")

    except Exception as e:
        print(f"✗ Error during forward pass: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "=" * 60)
    print("Test completed successfully!")
    print("=" * 60)
|
|
|
|
|
|
# Script entry point: run the smoke test only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()