format
@@ -1,6 +1,7 @@
 import concurrent.futures
 import glob
 import json
+import logging
 import os
 import re
 import sys
@@ -12,6 +13,8 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+logger = logging.getLogger(__name__)
+
 
 def _ensure_repo_paths_importable(current_file: str) -> None:
     """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
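A note on the new module-level logger: logging.getLogger(__name__) emits nothing until the embedding application configures logging, because the root logger defaults to the WARNING level. A minimal, self-contained sketch of how a caller would surface the logger.info() lines this commit adds (the script body is illustrative, not part of the commit):

import logging

logger = logging.getLogger(__name__)


def demo() -> None:
    # Dropped: no handler is configured yet and the default level is WARNING.
    logger.info("invisible")
    # Once the application calls basicConfig, INFO records reach stderr.
    logging.basicConfig(level=logging.INFO)
    logger.info("visible")


if __name__ == "__main__":
    demo()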
@@ -203,6 +206,8 @@ def _select_device_and_dtype():
 
 
 def _load_colvision(model_choice: str):
+    import os
+
     import torch
     from colpali_engine.models import (
         ColPali,
@@ -214,6 +219,16 @@ def _load_colvision(model_choice: str):
     from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
     from transformers.utils.import_utils import is_flash_attn_2_available
 
+    # Force HuggingFace Hub downloads through the HF endpoint, not Google Drive.
+    # setdefault only applies when the variables are not already set by the user.
+    os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")
+    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+
+    # Log model loading info.
+    logger.info(f"Loading ColVision model: {model_choice}")
+    logger.info(f"HF_ENDPOINT: {os.environ.get('HF_ENDPOINT', 'not set')}")
+    logger.info("Models will be downloaded from HuggingFace Hub, not Google Drive")
+
     device_str, device, dtype = _select_device_and_dtype()
 
     # Determine model name and type
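os.environ.setdefault writes a value only when the variable is absent, so a user-supplied HF_ENDPOINT (for example, a private mirror) is never clobbered by the defaults above. A small sketch of that behavior (the mirror URL is made up for illustration, and HF_HUB_ENABLE_HF_TRANSFER is assumed to start out unset):

import os

# Simulate a user who already points at a private mirror (hypothetical URL).
os.environ["HF_ENDPOINT"] = "https://hf-mirror.internal.example"

os.environ.setdefault("HF_ENDPOINT", "https://huggingface.co")  # no-op: already set
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")         # applied: was unset

print(os.environ["HF_ENDPOINT"])                # https://hf-mirror.internal.example
print(os.environ["HF_HUB_ENABLE_HF_TRANSFER"])  # 1

One caveat: huggingface_hub errors out at download time when HF_HUB_ENABLE_HF_TRANSFER is set but the optional hf_transfer package is not installed, so this default assumes that dependency is present.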
@@ -254,29 +269,36 @@ def _load_colvision(model_choice: str):
         "flash_attention_2" if (device_str == "cuda" and is_flash_attn_2_available()) else "eager"
     )
 
+    # Load model from HuggingFace Hub (not Google Drive).
+    # Use local_files_only=False to ensure download from HF if not cached.
     if model_type == "colqwen2.5":
         model = ColQwen2_5.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map=device,
             attn_implementation=attn_implementation,
+            local_files_only=False,  # ensure download from HuggingFace Hub
         ).eval()
-        processor = ColQwen2_5_Processor.from_pretrained(model_name)
+        processor = ColQwen2_5_Processor.from_pretrained(model_name, local_files_only=False)
     elif model_type == "colqwen2":
         model = ColQwen2.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map=device,
             attn_implementation=attn_implementation,
+            local_files_only=False,  # ensure download from HuggingFace Hub
         ).eval()
-        processor = ColQwen2Processor.from_pretrained(model_name)
+        processor = ColQwen2Processor.from_pretrained(model_name, local_files_only=False)
     else:  # colpali
         model = ColPali.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map=device,
+            local_files_only=False,  # ensure download from HuggingFace Hub
         ).eval()
-        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
+        processor = cast(
+            ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name, local_files_only=False)
+        )
 
     return model_name, model, processor, device_str, device, dtype
 
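For reference, local_files_only is a standard transformers from_pretrained argument that the colpali_engine model classes inherit: False (the default) permits a Hub download when the weights are not in the local cache, while True restricts loading to the cache and raises OSError otherwise. A minimal sketch of the distinction, using an example checkpoint id (vidore/colpali-v1.2 is assumed here, not taken from the diff):

from colpali_engine.models import ColPali

MODEL_ID = "vidore/colpali-v1.2"  # example checkpoint id

# Online-first load, as in this commit: download from the Hub if not cached.
model = ColPali.from_pretrained(MODEL_ID, local_files_only=False).eval()

# Cache-only load: fails fast instead of touching the network.
try:
    model = ColPali.from_pretrained(MODEL_ID, local_files_only=True).eval()
except OSError:
    print("not cached locally; re-run once with local_files_only=False")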