From 130c75f5bf4e029d8f4f7ab5658cd4aded699740 Mon Sep 17 00:00:00 2001 From: boeto <34640489+boeto@users.noreply.github.com> Date: Mon, 3 Feb 2025 20:30:07 +0800 Subject: [PATCH] fix huggingface download with tokens (#116) --- py/download.py | 9 ++++++--- py/information.py | 2 +- py/utils.py | 10 ++++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/py/download.py b/py/download.py index b7f15e0..7394a77 100644 --- a/py/download.py +++ b/py/download.py @@ -103,6 +103,8 @@ def get_task_content(task_id: str): if not os.path.isfile(task_file): raise RuntimeError(f"Task {task_id} not found") task_content = utils.load_dict_pickle_file(task_file) + if isinstance(task_content, TaskContent): + return task_content return TaskContent(**task_content) @@ -178,17 +180,18 @@ async def create_model_download_task(task_data: dict, request): task_path = utils.join_path(download_path, f"{task_id}.task") if os.path.exists(task_path): raise RuntimeError(f"Task {task_id} already exists") + download_platform = task_data.get("downloadPlatform", None) try: - previewFile = task_data.pop("previewFile", None) - utils.save_model_preview_image(task_path, previewFile) + preview_file = task_data.pop("previewFile", None) + utils.save_model_preview_image(task_path, preview_file, download_platform) set_task_content(task_id, task_data) task_status = TaskStatus( taskId=task_id, type=model_type, fullname=fullname, preview=utils.get_model_preview_name(task_path), - platform=task_data.get("downloadPlatform", None), + platform=download_platform, totalSize=float(task_data.get("sizeBytes", 0)), ) download_model_task_status[task_id] = task_status diff --git a/py/information.py b/py/information.py index 3a804a8..d217796 100644 --- a/py/information.py +++ b/py/information.py @@ -225,7 +225,7 @@ class HuggingfaceModelSearcher(ModelSearcher): "pathIndex": 0, "description": "\n".join(description_parts), "metadata": {}, - "downloadPlatform": "", + "downloadPlatform": "huggingface", "downloadUrl": f"https://huggingface.co/{model_id}/resolve/main/{filename}?download=true", } models.append(model) diff --git a/py/utils.py b/py/utils.py index 040c74f..eed5551 100644 --- a/py/utils.py +++ b/py/utils.py @@ -277,10 +277,9 @@ def remove_model_preview_image(model_path: str): os.remove(preview_path) -def save_model_preview_image(model_path: str, image_file_or_url: Any): +def save_model_preview_image(model_path: str, image_file_or_url: Any, platform: str | None = None): basename = os.path.splitext(model_path)[0] preview_path = f"{basename}.webp" - # Download image file if it is url if type(image_file_or_url) is str: image_url = image_file_or_url @@ -304,8 +303,11 @@ def save_model_preview_image(model_path: str, image_file_or_url: Any): content_type: str = image_file.content_type if not content_type.startswith("image/"): - raise RuntimeError(f"FileTypeError: expected image, got {content_type}") - + if platform == "huggingface": + # huggingface previewFile content_type='text/plain', not startswith("image/") + return + else: + raise RuntimeError(f"FileTypeError: expected image, got {content_type}") image = Image.open(image_file.file) image.save(preview_path, "WEBP")