Add ColQwen2.5 model support and improve model selection (#183)
- Add ColQwen2.5 and ColQwen2_5_Processor imports
- Implement smart model type detection for colqwen2, colqwen2.5, and colpali
- Add task name aliases for easier benchmark invocation
- Add safe model name handling for file paths and index naming
- Support custom model paths including LoRA adapters
- Improve model choice validation and error handling

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
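The exact-alias checks and the substring checks in the diff below are order-sensitive: every "colqwen2.5" name also contains the substring "colqwen2", so the 2.5 branches must be tested first. Here is a minimal standalone sketch of that precedence; the helper name `_detect_model` and its pure-function form are illustrative, not part of the change:

```python
def _detect_model(model_choice: str) -> tuple[str, str]:
    """Illustrative sketch only: map a model choice to (model_name, model_type),
    mirroring the branch order added to _load_colvision in the diff below."""
    lower = model_choice.lower()
    # Exact aliases map to the pinned default checkpoints.
    if model_choice == "colqwen2":
        return "vidore/colqwen2-v1.0", "colqwen2"
    if model_choice in ("colqwen2.5", "colqwen25"):
        return "vidore/colqwen2.5-v0.2", "colqwen2.5"
    if model_choice == "colpali":
        return "vidore/colpali-v1.2", "colpali"
    # Substring checks handle HF repo ids and local paths (e.g. LoRA adapter dirs).
    # colqwen2.5 MUST be tested before colqwen2, otherwise a name like
    # "vidore/colqwen2.5-v0.2" would be claimed by the colqwen2 branch.
    if any(tag in lower for tag in ("colqwen2.5", "colqwen25", "colqwen2_5")):
        return model_choice, "colqwen2.5"
    if "colqwen2" in lower and "colqwen2-v1.0" in lower:
        return model_choice, "colqwen2"
    if "colpali" in lower:
        return model_choice, "colpali"
    # Anything else falls back to colpali for backward compatibility.
    return "vidore/colpali-v1.2", "colpali"


assert _detect_model("colqwen25") == ("vidore/colqwen2.5-v0.2", "colqwen2.5")
assert _detect_model("vidore/colqwen2.5-v0.2")[1] == "colqwen2.5"
assert _detect_model("./checkpoints/colqwen2_5-lora")[1] == "colqwen2.5"
```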
@@ -152,20 +152,65 @@ def _select_device_and_dtype():
 def _load_colvision(model_choice: str):
     import torch
-    from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
+    from colpali_engine.models import (
+        ColPali,
+        ColQwen2,
+        ColQwen2_5,
+        ColQwen2_5_Processor,
+        ColQwen2Processor,
+    )
     from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
     from transformers.utils.import_utils import is_flash_attn_2_available
 
     device_str, device, dtype = _select_device_and_dtype()
 
+    # Determine model name and type
+    # IMPORTANT: Check colqwen2.5 BEFORE colqwen2 to avoid false matches
+    model_choice_lower = model_choice.lower()
     if model_choice == "colqwen2":
         model_name = "vidore/colqwen2-v1.0"
-        # On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
-        attn_implementation = (
-            "flash_attention_2"
-            if (device_str == "cuda" and is_flash_attn_2_available())
-            else "eager"
-        )
+        model_type = "colqwen2"
+    elif model_choice == "colqwen2.5" or model_choice == "colqwen25":
+        model_name = "vidore/colqwen2.5-v0.2"
+        model_type = "colqwen2.5"
+    elif model_choice == "colpali":
+        model_name = "vidore/colpali-v1.2"
+        model_type = "colpali"
+    elif (
+        "colqwen2.5" in model_choice_lower
+        or "colqwen25" in model_choice_lower
+        or "colqwen2_5" in model_choice_lower
+    ):
+        # Handle HuggingFace model names like "vidore/colqwen2.5-v0.2"
+        model_name = model_choice
+        model_type = "colqwen2.5"
+    elif "colqwen2" in model_choice_lower and "colqwen2-v1.0" in model_choice_lower:
+        # Handle HuggingFace model names like "vidore/colqwen2-v1.0" (but not colqwen2.5)
+        model_name = model_choice
+        model_type = "colqwen2"
+    elif "colpali" in model_choice_lower:
+        # Handle HuggingFace model names like "vidore/colpali-v1.2"
+        model_name = model_choice
+        model_type = "colpali"
+    else:
+        # Default to colpali for backward compatibility
+        model_name = "vidore/colpali-v1.2"
+        model_type = "colpali"
+
+    # Load model based on type
+    attn_implementation = (
+        "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
+    )
+
+    if model_type == "colqwen2.5":
+        model = ColQwen2_5.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        ).eval()
+        processor = ColQwen2_5_Processor.from_pretrained(model_name)
+    elif model_type == "colqwen2":
+        model = ColQwen2.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
@@ -173,8 +218,7 @@ def _load_colvision(model_choice: str):
             attn_implementation=attn_implementation,
         ).eval()
         processor = ColQwen2Processor.from_pretrained(model_name)
-    else:
-        model_name = "vidore/colpali-v1.2"
+    else: # colpali
         model = ColPali.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
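The commit message also mentions safe model name handling for file paths and index naming, which is not visible in this hunk. The sketch below is only a guess at what such a helper could look like; the name `safe_model_name`, the regex, and the exact behavior are assumptions, not taken from the change. The motivation is that model choices such as "vidore/colqwen2.5-v0.2" or a local LoRA adapter path contain "/" and "." characters that are awkward inside file names and index identifiers:

```python
import re


def safe_model_name(model_choice: str) -> str:
    """Hypothetical helper, not from this diff: reduce a model choice (HF repo id,
    alias, or local LoRA path) to characters safe for file names and index names."""
    # Collapse every run of characters outside [A-Za-z0-9_-] into a single "_".
    return re.sub(r"[^A-Za-z0-9_-]+", "_", model_choice).strip("_")


# e.g. safe_model_name("vidore/colqwen2.5-v0.2")       -> "vidore_colqwen2_5-v0_2"
# e.g. safe_model_name("./checkpoints/colqwen2_5-lora") -> "checkpoints_colqwen2_5-lora"
```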