This commit is contained in:
yichuan-w
2025-12-20 01:27:54 +00:00
parent 12951ad4d5
commit 9996c29618

View File

@@ -1,6 +1,7 @@
import concurrent.futures
import glob
import json
import logging
import os
import re
import sys
@@ -12,6 +13,8 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
logger = logging.getLogger(__name__)
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
@@ -203,6 +206,8 @@ def _select_device_and_dtype():
def _load_colvision(model_choice: str):
import os
import torch
from colpali_engine.models import (
ColPali,
@@ -214,6 +219,16 @@ def _load_colvision(model_choice: str):
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
# Set environment variables to ensure models are downloaded from HuggingFace
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Log model loading info
logger.info(f"Loading ColVision model: {model_choice}")
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
device_str, device, dtype = _select_device_and_dtype()
# Determine model name and type
@@ -254,29 +269,36 @@ def _load_colvision(model_choice: str):
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
)
# Load model from HuggingFace Hub (not Google Drive)
# Use local_files_only=False to ensure download from HF if not cached
if model_type == "colqwen2.5":
model = ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = ColQwen2_5_Processor.from_pretrained(model_name)
processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
elif model_type == "colqwen2":
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
else: # colpali
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
processor = cast(
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
)
return model_name, model, processor, device_str, device, dtype