Compare commits
3 Commits
revert-161
...
feat/add-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aaadb00e44 | ||
|
|
76cc798e3e | ||
|
|
d599566fd7 |
@@ -8,10 +8,9 @@ from pathlib import Path
|
|||||||
# Add the current directory to path to import leann_multi_vector
|
# Add the current directory to path to import leann_multi_vector
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
import torch
|
||||||
|
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||||
from leann_multi_vector import _load_colvision, _embed_images, _ensure_repo_paths_importable
|
from PIL import Image
|
||||||
|
|
||||||
# Ensure repo paths are importable
|
# Ensure repo paths are importable
|
||||||
_ensure_repo_paths_importable(__file__)
|
_ensure_repo_paths_importable(__file__)
|
||||||
@@ -23,7 +22,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|||||||
def create_test_image():
    """Build a blank white 800x600 RGB image for exercising the forward pass."""
    width, height = 800, 600
    return Image.new("RGB", (width, height), color="white")
|
||||||
|
|
||||||
|
|
||||||
@@ -42,8 +41,8 @@ def load_test_image_from_file():
|
|||||||
for img_dir in possible_paths:
|
for img_dir in possible_paths:
|
||||||
if img_dir.exists():
|
if img_dir.exists():
|
||||||
# Find first image file
|
# Find first image file
|
||||||
for ext in ['.png', '.jpg', '.jpeg']:
|
for ext in [".png", ".jpg", ".jpeg"]:
|
||||||
for img_file in img_dir.glob(f'*{ext}'):
|
for img_file in img_dir.glob(f"*{ext}"):
|
||||||
print(f"Loading test image from: {img_file}")
|
print(f"Loading test image from: {img_file}")
|
||||||
return Image.open(img_file)
|
return Image.open(img_file)
|
||||||
|
|
||||||
@@ -65,8 +64,8 @@ def main():
|
|||||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||||
|
|
||||||
# Convert to RGB if needed
|
# Convert to RGB if needed
|
||||||
if test_image.mode != 'RGB':
|
if test_image.mode != "RGB":
|
||||||
test_image = test_image.convert('RGB')
|
test_image = test_image.convert("RGB")
|
||||||
print(f"✓ Converted to RGB: {test_image.size}")
|
print(f"✓ Converted to RGB: {test_image.size}")
|
||||||
|
|
||||||
# Step 2: Load model
|
# Step 2: Load model
|
||||||
@@ -77,14 +76,15 @@ def main():
|
|||||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||||
|
|
||||||
# Print model info
|
# Print model info
|
||||||
if hasattr(model, 'device'):
|
if hasattr(model, "device"):
|
||||||
print(f"✓ Model device: {model.device}")
|
print(f"✓ Model device: {model.device}")
|
||||||
if hasattr(model, 'dtype'):
|
if hasattr(model, "dtype"):
|
||||||
print(f"✓ Model dtype: {model.dtype}")
|
print(f"✓ Model dtype: {model.dtype}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error loading model: {e}")
|
print(f"✗ Error loading model: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -97,14 +97,14 @@ def main():
|
|||||||
|
|
||||||
doc_vecs = _embed_images(model, processor, images)
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
|
||||||
print(f"✓ Forward pass completed!")
|
print("✓ Forward pass completed!")
|
||||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||||
|
|
||||||
if len(doc_vecs) > 0:
|
if len(doc_vecs) > 0:
|
||||||
emb = doc_vecs[0]
|
emb = doc_vecs[0]
|
||||||
print(f"✓ Embedding shape: {emb.shape}")
|
print(f"✓ Embedding shape: {emb.shape}")
|
||||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||||
print(f"✓ Embedding stats:")
|
print("✓ Embedding stats:")
|
||||||
print(f" - Min: {emb.min().item():.4f}")
|
print(f" - Min: {emb.min().item():.4f}")
|
||||||
print(f" - Max: {emb.max().item():.4f}")
|
print(f" - Max: {emb.max().item():.4f}")
|
||||||
print(f" - Mean: {emb.mean().item():.4f}")
|
print(f" - Mean: {emb.mean().item():.4f}")
|
||||||
@@ -119,6 +119,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error during forward pass: {e}")
|
print(f"✗ Error during forward pass: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -129,4 +130,3 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -152,20 +152,65 @@ def _select_device_and_dtype():
|
|||||||
|
|
||||||
def _load_colvision(model_choice: str):
|
def _load_colvision(model_choice: str):
|
||||||
import torch
|
import torch
|
||||||
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
from colpali_engine.models import (
|
||||||
|
ColPali,
|
||||||
|
ColQwen2,
|
||||||
|
ColQwen2_5,
|
||||||
|
ColQwen2_5_Processor,
|
||||||
|
ColQwen2Processor,
|
||||||
|
)
|
||||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
device_str, device, dtype = _select_device_and_dtype()
|
device_str, device, dtype = _select_device_and_dtype()
|
||||||
|
|
||||||
|
# Determine model name and type
|
||||||
|
# IMPORTANT: Check colqwen2.5 BEFORE colqwen2 to avoid false matches
|
||||||
|
model_choice_lower = model_choice.lower()
|
||||||
if model_choice == "colqwen2":
|
if model_choice == "colqwen2":
|
||||||
model_name = "vidore/colqwen2-v1.0"
|
model_name = "vidore/colqwen2-v1.0"
|
||||||
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
model_type = "colqwen2"
|
||||||
attn_implementation = (
|
elif model_choice == "colqwen2.5" or model_choice == "colqwen25":
|
||||||
"flash_attention_2"
|
model_name = "vidore/colqwen2.5-v0.2"
|
||||||
if (device_str == "cuda" and is_flash_attn_2_available())
|
model_type = "colqwen2.5"
|
||||||
else "eager"
|
elif model_choice == "colpali":
|
||||||
)
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model_type = "colpali"
|
||||||
|
elif (
|
||||||
|
"colqwen2.5" in model_choice_lower
|
||||||
|
or "colqwen25" in model_choice_lower
|
||||||
|
or "colqwen2_5" in model_choice_lower
|
||||||
|
):
|
||||||
|
# Handle HuggingFace model names like "vidore/colqwen2.5-v0.2"
|
||||||
|
model_name = model_choice
|
||||||
|
model_type = "colqwen2.5"
|
||||||
|
elif "colqwen2" in model_choice_lower and "colqwen2-v1.0" in model_choice_lower:
|
||||||
|
# Handle HuggingFace model names like "vidore/colqwen2-v1.0" (but not colqwen2.5)
|
||||||
|
model_name = model_choice
|
||||||
|
model_type = "colqwen2"
|
||||||
|
elif "colpali" in model_choice_lower:
|
||||||
|
# Handle HuggingFace model names like "vidore/colpali-v1.2"
|
||||||
|
model_name = model_choice
|
||||||
|
model_type = "colpali"
|
||||||
|
else:
|
||||||
|
# Default to colpali for backward compatibility
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model_type = "colpali"
|
||||||
|
|
||||||
|
# Load model based on type
|
||||||
|
attn_implementation = (
|
||||||
|
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type == "colqwen2.5":
|
||||||
|
model = ColQwen2_5.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
).eval()
|
||||||
|
processor = ColQwen2_5_Processor.from_pretrained(model_name)
|
||||||
|
elif model_type == "colqwen2":
|
||||||
model = ColQwen2.from_pretrained(
|
model = ColQwen2.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
@@ -173,8 +218,7 @@ def _load_colvision(model_choice: str):
|
|||||||
attn_implementation=attn_implementation,
|
attn_implementation=attn_implementation,
|
||||||
).eval()
|
).eval()
|
||||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||||
else:
|
else: # colpali
|
||||||
model_name = "vidore/colpali-v1.2"
|
|
||||||
model = ColPali.from_pretrained(
|
model = ColPali.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
|
|||||||
@@ -90,6 +90,51 @@ VIDORE_V1_TASKS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Task name aliases (short, lowercase names -> full task names).
# Declaration order matters: the partial-match fallback in
# normalize_task_name() returns the first alias that matches.
TASK_ALIASES = {
    "arxivqa": "VidoreArxivQARetrieval",
    "docvqa": "VidoreDocVQARetrieval",
    "infovqa": "VidoreInfoVQARetrieval",
    "tabfquad": "VidoreTabfquadRetrieval",
    "tatdqa": "VidoreTatdqaRetrieval",
    "shiftproject": "VidoreShiftProjectRetrieval",
    "syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
    "syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
    "syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
    "syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
}


def normalize_task_name(task_name: str) -> str:
    """Resolve a user-supplied task name or alias to a full task name.

    Resolution order:

    1. A name that is already a key of ``VIDORE_V1_TASKS`` is returned as is.
    2. A case-insensitive exact hit in ``TASK_ALIASES`` returns its full name.
    3. Otherwise the first alias (in declaration order) that contains the
       lowercased name, or is contained in it, wins.

    Args:
        task_name: Full task name, short alias, or a fragment of either.

    Returns:
        The resolved full task name, or ``task_name`` unchanged when nothing
        matches (including empty / whitespace-only input), so the caller can
        report it as unknown.
    """
    # An empty string is a substring of every alias, so without this guard
    # the partial match below would silently pick the first task.
    if not task_name.strip():
        return task_name

    if task_name in VIDORE_V1_TASKS:
        return task_name

    task_name_lower = task_name.lower()
    if task_name_lower in TASK_ALIASES:
        return TASK_ALIASES[task_name_lower]

    # Partial match in either direction; first declared alias wins.
    for alias, full_name in TASK_ALIASES.items():
        if alias in task_name_lower or task_name_lower in alias:
            return full_name

    return task_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_model_name(model_name: str) -> str:
    """Return a token derived from *model_name* that is safe inside a file name.

    Args:
        model_name: A model identifier such as ``"colqwen2"`` or
            ``"vidore/colqwen2-v1.0"``, or a path to a local model directory.

    Returns:
        * For an existing directory: its basename when that is non-empty,
          shorter than 100 characters and not a dot-name; otherwise the first
          16 hex digits of the MD5 of the path string as given.
        * For anything else: the name itself.

        In both name cases ``/``, ``\\`` and ``:`` are replaced by ``_`` so
        the result is always a single path component.
    """
    import hashlib
    import os

    # Separators (both flavours) and colons would split or break a file name.
    unsafe_chars = str.maketrans({"/": "_", "\\": "_", ":": "_"})

    if os.path.isdir(model_name):
        # Strip every trailing separator this platform recognises, so that
        # "models/foo/" (and "models\\foo\\" on Windows) yields "foo".
        separators = "/" + os.sep + (os.altsep or "")
        basename = os.path.basename(model_name.rstrip(separators))
        if basename and len(basename) < 100 and not basename.startswith("."):
            return basename.translate(unsafe_chars)
        # MD5 is used only as a short, stable fingerprint of the path text,
        # not for security; very long or dot-named paths land here.
        return hashlib.md5(model_name.encode()).hexdigest()[:16]

    # Not a local directory: treat as a hub-style name.
    return model_name.translate(unsafe_chars)
|
||||||
|
|
||||||
|
|
||||||
def load_vidore_v1_data(
|
def load_vidore_v1_data(
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
@@ -181,6 +226,9 @@ def evaluate_task(
|
|||||||
print(f"Evaluating task: {task_name}")
|
print(f"Evaluating task: {task_name}")
|
||||||
print(f"{'=' * 80}")
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Normalize task name (handle aliases)
|
||||||
|
task_name = normalize_task_name(task_name)
|
||||||
|
|
||||||
# Get task config
|
# Get task config
|
||||||
if task_name not in VIDORE_V1_TASKS:
|
if task_name not in VIDORE_V1_TASKS:
|
||||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||||
@@ -223,11 +271,13 @@ def evaluate_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build or load index
|
# Build or load index
|
||||||
|
# Use safe model name for index path (different models need different indexes)
|
||||||
|
safe_model_name = get_safe_model_name(model_name)
|
||||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||||
if index_path_full is None:
|
if index_path_full is None:
|
||||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||||
if use_fast_plaid:
|
if use_fast_plaid:
|
||||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||||
|
|
||||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||||
corpus=corpus,
|
corpus=corpus,
|
||||||
@@ -281,8 +331,7 @@ def main():
|
|||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default="colqwen2",
|
default="colqwen2",
|
||||||
choices=["colqwen2", "colpali"],
|
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||||
help="Model to use",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--task",
|
"--task",
|
||||||
@@ -350,11 +399,11 @@ def main():
|
|||||||
|
|
||||||
# Determine tasks to evaluate
|
# Determine tasks to evaluate
|
||||||
if args.task:
|
if args.task:
|
||||||
tasks_to_eval = [args.task]
|
tasks_to_eval = [normalize_task_name(args.task)]
|
||||||
elif args.tasks.lower() == "all":
|
elif args.tasks.lower() == "all":
|
||||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||||
else:
|
else:
|
||||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||||
|
|
||||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user