fix huggingface download with tokens (#116)

This commit is contained in:
boeto
2025-02-03 20:30:07 +08:00
committed by GitHub
parent 921dabc057
commit 130c75f5bf
3 changed files with 13 additions and 8 deletions

View File

@@ -103,6 +103,8 @@ def get_task_content(task_id: str):
if not os.path.isfile(task_file):
raise RuntimeError(f"Task {task_id} not found")
task_content = utils.load_dict_pickle_file(task_file)
if isinstance(task_content, TaskContent):
return task_content
return TaskContent(**task_content)
@@ -178,17 +180,18 @@ async def create_model_download_task(task_data: dict, request):
task_path = utils.join_path(download_path, f"{task_id}.task")
if os.path.exists(task_path):
raise RuntimeError(f"Task {task_id} already exists")
download_platform = task_data.get("downloadPlatform", None)
try:
previewFile = task_data.pop("previewFile", None)
utils.save_model_preview_image(task_path, previewFile)
preview_file = task_data.pop("previewFile", None)
utils.save_model_preview_image(task_path, preview_file, download_platform)
set_task_content(task_id, task_data)
task_status = TaskStatus(
taskId=task_id,
type=model_type,
fullname=fullname,
preview=utils.get_model_preview_name(task_path),
platform=task_data.get("downloadPlatform", None),
platform=download_platform,
totalSize=float(task_data.get("sizeBytes", 0)),
)
download_model_task_status[task_id] = task_status

View File

@@ -225,7 +225,7 @@ class HuggingfaceModelSearcher(ModelSearcher):
"pathIndex": 0,
"description": "\n".join(description_parts),
"metadata": {},
"downloadPlatform": "",
"downloadPlatform": "huggingface",
"downloadUrl": f"https://huggingface.co/{model_id}/resolve/main/{filename}?download=true",
}
models.append(model)

View File

@@ -277,10 +277,9 @@ def remove_model_preview_image(model_path: str):
os.remove(preview_path)
def save_model_preview_image(model_path: str, image_file_or_url: Any):
def save_model_preview_image(model_path: str, image_file_or_url: Any, platform: str | None = None):
basename = os.path.splitext(model_path)[0]
preview_path = f"{basename}.webp"
# Download image file if it is url
if type(image_file_or_url) is str:
image_url = image_file_or_url
@@ -304,8 +303,11 @@ def save_model_preview_image(model_path: str, image_file_or_url: Any):
content_type: str = image_file.content_type
if not content_type.startswith("image/"):
raise RuntimeError(f"FileTypeError: expected image, got {content_type}")
if platform == "huggingface":
# huggingface previewFile content_type='text/plain', not startswith("image/")
return
else:
raise RuntimeError(f"FileTypeError: expected image, got {content_type}")
image = Image.open(image_file.file)
image.save(preview_path, "WEBP")