This commit is contained in:
yichuan-w
2025-12-20 01:27:54 +00:00
parent 12951ad4d5
commit 9996c29618

View File

@@ -1,6 +1,7 @@
import concurrent.futures import concurrent.futures
import glob import glob
import json import json
import logging
import os import os
import re import re
import sys import sys
@@ -12,6 +13,8 @@ import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
logger = logging.getLogger(__name__)
def _ensure_repo_paths_importable(current_file: str) -> None: def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py).""" """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
@@ -203,6 +206,8 @@ def _select_device_and_dtype():
def _load_colvision(model_choice: str): def _load_colvision(model_choice: str):
import os
import torch import torch
from colpali_engine.models import ( from colpali_engine.models import (
ColPali, ColPali,
@@ -214,6 +219,16 @@ def _load_colvision(model_choice: str):
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available from transformers.utils.import_utils import is_flash_attn_2_available
# Force HuggingFace Hub to use HF endpoint, avoid Google Drive
# Set environment variables to ensure models are downloaded from HuggingFace
os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Log model loading info
logger.info(f"Loading ColVision model: {model_choice}")
logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
device_str, device, dtype = _select_device_and_dtype() device_str, device, dtype = _select_device_and_dtype()
# Determine model name and type # Determine model name and type
@@ -254,29 +269,36 @@ def _load_colvision(model_choice: str):
"flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager" "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
) )
# Load model from HuggingFace Hub (not Google Drive)
# Use local_files_only=False to ensure download from HF if not cached
if model_type == "colqwen2.5": if model_type == "colqwen2.5":
model = ColQwen2_5.from_pretrained( model = ColQwen2_5.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = ColQwen2_5_Processor.from_pretrained(model_name) processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
elif model_type == "colqwen2": elif model_type == "colqwen2":
model = ColQwen2.from_pretrained( model = ColQwen2.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = ColQwen2Processor.from_pretrained(model_name) processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
else: # colpali else: # colpali
model = ColPali.from_pretrained( model = ColPali.from_pretrained(
model_name, model_name,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map=device, device_map=device,
local_files_only=False, # Ensure download from HuggingFace Hub
).eval() ).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name)) processor = cast(
ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
)
return model_name, model, processor, device_str, device, dtype return model_name, model, processor, device_str, device, dtype