fix huggingface download with tokens (#116)
This commit is contained in:
@@ -103,6 +103,8 @@ def get_task_content(task_id: str):
|
||||
if not os.path.isfile(task_file):
|
||||
raise RuntimeError(f"Task {task_id} not found")
|
||||
task_content = utils.load_dict_pickle_file(task_file)
|
||||
if isinstance(task_content, TaskContent):
|
||||
return task_content
|
||||
return TaskContent(**task_content)
|
||||
|
||||
|
||||
@@ -178,17 +180,18 @@ async def create_model_download_task(task_data: dict, request):
|
||||
task_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
if os.path.exists(task_path):
|
||||
raise RuntimeError(f"Task {task_id} already exists")
|
||||
download_platform = task_data.get("downloadPlatform", None)
|
||||
|
||||
try:
|
||||
previewFile = task_data.pop("previewFile", None)
|
||||
utils.save_model_preview_image(task_path, previewFile)
|
||||
preview_file = task_data.pop("previewFile", None)
|
||||
utils.save_model_preview_image(task_path, preview_file, download_platform)
|
||||
set_task_content(task_id, task_data)
|
||||
task_status = TaskStatus(
|
||||
taskId=task_id,
|
||||
type=model_type,
|
||||
fullname=fullname,
|
||||
preview=utils.get_model_preview_name(task_path),
|
||||
platform=task_data.get("downloadPlatform", None),
|
||||
platform=download_platform,
|
||||
totalSize=float(task_data.get("sizeBytes", 0)),
|
||||
)
|
||||
download_model_task_status[task_id] = task_status
|
||||
|
||||
@@ -225,7 +225,7 @@ class HuggingfaceModelSearcher(ModelSearcher):
|
||||
"pathIndex": 0,
|
||||
"description": "\n".join(description_parts),
|
||||
"metadata": {},
|
||||
"downloadPlatform": "",
|
||||
"downloadPlatform": "huggingface",
|
||||
"downloadUrl": f"https://huggingface.co/{model_id}/resolve/main/{filename}?download=true",
|
||||
}
|
||||
models.append(model)
|
||||
|
||||
10
py/utils.py
10
py/utils.py
@@ -277,10 +277,9 @@ def remove_model_preview_image(model_path: str):
|
||||
os.remove(preview_path)
|
||||
|
||||
|
||||
def save_model_preview_image(model_path: str, image_file_or_url: Any):
|
||||
def save_model_preview_image(model_path: str, image_file_or_url: Any, platform: str | None = None):
|
||||
basename = os.path.splitext(model_path)[0]
|
||||
preview_path = f"{basename}.webp"
|
||||
|
||||
# Download image file if it is url
|
||||
if type(image_file_or_url) is str:
|
||||
image_url = image_file_or_url
|
||||
@@ -304,8 +303,11 @@ def save_model_preview_image(model_path: str, image_file_or_url: Any):
|
||||
|
||||
content_type: str = image_file.content_type
|
||||
if not content_type.startswith("image/"):
|
||||
raise RuntimeError(f"FileTypeError: expected image, got {content_type}")
|
||||
|
||||
if platform == "huggingface":
|
||||
# huggingface previewFile content_type='text/plain', not startswith("image/")
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(f"FileTypeError: expected image, got {content_type}")
|
||||
image = Image.open(image_file.file)
|
||||
image.save(preview_path, "WEBP")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user