fix: download module error (#141)

This commit is contained in:
Hayden
2025-02-19 14:37:27 +08:00
committed by GitHub
parent ea26ec5098
commit 05fa31f2c5

View File

@@ -3,8 +3,13 @@ import uuid
import time import time
import requests import requests
import folder_paths import folder_paths
from typing import Callable, Awaitable, Any, Literal, Union, Optional from typing import Callable, Awaitable, Any, Literal, Union, Optional
from dataclasses import dataclass from dataclasses import dataclass
from aiohttp import web
from . import config from . import config
from . import utils from . import utils
from . import thread from . import thread
@@ -87,17 +92,95 @@ class TaskContent:
} }
class ModelDownload:
def add_routes(self, routes):
@routes.get("/model-manager/download/task")
async def scan_download_tasks(request):
"""
Read download task list.
"""
try:
result = await self.scan_model_download_task_list()
return web.json_response({"success": True, "data": result})
except Exception as e:
error_msg = f"Read download task list failed: {e}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
@routes.put("/model-manager/download/{task_id}")
async def resume_download_task(request):
"""
Toggle download task status.
"""
try:
task_id = request.match_info.get("task_id", None)
if task_id is None:
raise web.HTTPBadRequest(reason="Invalid task id")
json_data = await request.json()
status = json_data.get("status", None)
if status == "pause":
await self.pause_model_download_task(task_id)
elif status == "resume":
await self.download_model(task_id, request)
else:
raise web.HTTPBadRequest(reason="Invalid status")
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Resume download task failed: {str(e)}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
@routes.delete("/model-manager/download/{task_id}")
async def delete_model_download_task(request):
"""
Delete download task.
"""
task_id = request.match_info.get("task_id", None)
try:
await self.delete_model_download_task(task_id)
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Delete download task failed: {str(e)}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
@routes.post("/model-manager/model")
async def create_model(request):
"""
Create a new model.
request body: x-www-form-urlencoded
- type: model type.
- pathIndex: index of the model folders.
- fullname: filename that relative to the model folder.
- previewFile: preview file.
- description: description.
- downloadPlatform: download platform.
- downloadUrl: download url.
- hash: a JSON string containing the hash value of the downloaded model.
"""
task_data = await request.post()
task_data = dict(task_data)
try:
task_id = await self.create_model_download_task(task_data, request)
return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e:
error_msg = f"Create model download task failed: {str(e)}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
download_model_task_status: dict[str, TaskStatus] = {} download_model_task_status: dict[str, TaskStatus] = {}
download_thread_pool = thread.DownloadThreadPool() download_thread_pool = thread.DownloadThreadPool()
def set_task_content(self, task_id: str, task_content: Union[TaskContent, dict]):
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_file_path = utils.join_path(download_path, f"{task_id}.task") task_file_path = utils.join_path(download_path, f"{task_id}.task")
utils.save_dict_pickle_file(task_file_path, task_content) utils.save_dict_pickle_file(task_file_path, task_content)
def get_task_content(self, task_id: str):
def get_task_content(task_id: str):
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_file = utils.join_path(download_path, f"{task_id}.task") task_file = utils.join_path(download_path, f"{task_id}.task")
if not os.path.isfile(task_file): if not os.path.isfile(task_file):
@@ -107,13 +190,12 @@ def get_task_content(task_id: str):
return task_content return task_content
return TaskContent(**task_content) return TaskContent(**task_content)
def get_task_status(self, task_id: str):
def get_task_status(task_id: str): task_status = self.download_model_task_status.get(task_id, None)
task_status = download_model_task_status.get(task_id, None)
if task_status is None: if task_status is None:
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_content = get_task_content(task_id) task_content = self.get_task_content(task_id)
download_file = utils.join_path(download_path, f"{task_id}.download") download_file = utils.join_path(download_path, f"{task_id}.download")
download_size = 0 download_size = 0
if os.path.exists(download_file): if os.path.exists(download_file):
@@ -131,16 +213,14 @@ def get_task_status(task_id: str):
progress=download_size / total_size * 100 if total_size > 0 else 0, progress=download_size / total_size * 100 if total_size > 0 else 0,
) )
download_model_task_status[task_id] = task_status self.download_model_task_status[task_id] = task_status
return task_status return task_status
def delete_task_status(self, task_id: str):
self.download_model_task_status.pop(task_id, None)
def delete_task_status(task_id: str): async def scan_model_download_task_list(self):
download_model_task_status.pop(task_id, None)
async def scan_model_download_task_list():
""" """
Scan the download directory and send the task list to the client. Scan the download directory and send the task list to the client.
""" """
@@ -155,13 +235,12 @@ async def scan_model_download_task_list():
task_list: list[dict] = [] task_list: list[dict] = []
for task_file in task_files: for task_file in task_files:
task_id = task_file.replace(".task", "") task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id) task_status = self.get_task_status(task_id)
task_list.append(task_status.to_dict()) task_list.append(task_status.to_dict())
return task_list return task_list
async def create_model_download_task(self, task_data: dict, request):
async def create_model_download_task(task_data: dict, request):
""" """
Creates a download task for the given data. Creates a download task for the given data.
""" """
@@ -185,7 +264,7 @@ async def create_model_download_task(task_data: dict, request):
try: try:
preview_file = task_data.pop("previewFile", None) preview_file = task_data.pop("previewFile", None)
utils.save_model_preview_image(task_path, preview_file, download_platform) utils.save_model_preview_image(task_path, preview_file, download_platform)
set_task_content(task_id, task_data) self.set_task_content(task_id, task_data)
task_status = TaskStatus( task_status = TaskStatus(
taskId=task_id, taskId=task_id,
type=model_type, type=model_type,
@@ -194,23 +273,21 @@ async def create_model_download_task(task_data: dict, request):
platform=download_platform, platform=download_platform,
totalSize=float(task_data.get("sizeBytes", 0)), totalSize=float(task_data.get("sizeBytes", 0)),
) )
download_model_task_status[task_id] = task_status self.download_model_task_status[task_id] = task_status
await utils.send_json("create_download_task", task_status.to_dict()) await utils.send_json("create_download_task", task_status.to_dict())
except Exception as e: except Exception as e:
await delete_model_download_task(task_id) await self.delete_model_download_task(task_id)
raise RuntimeError(str(e)) from e raise RuntimeError(str(e)) from e
await download_model(task_id, request) await self.download_model(task_id, request)
return task_id return task_id
async def pause_model_download_task(self, task_id: str):
async def pause_model_download_task(task_id: str): task_status = self.get_task_status(task_id=task_id)
task_status = get_task_status(task_id=task_id)
task_status.status = "pause" task_status.status = "pause"
async def delete_model_download_task(self, task_id: str):
async def delete_model_download_task(task_id: str): task_status = self.get_task_status(task_id)
task_status = get_task_status(task_id)
is_running = task_status.status == "doing" is_running = task_status.status == "doing"
task_status.status = "waiting" task_status.status = "waiting"
await utils.send_json("delete_download_task", task_id) await utils.send_json("delete_download_task", task_id)
@@ -225,20 +302,19 @@ async def delete_model_download_task(task_id: str):
for task_file in task_file_list: for task_file in task_file_list:
task_file_target = os.path.splitext(task_file)[0] task_file_target = os.path.splitext(task_file)[0]
if task_file_target == task_id: if task_file_target == task_id:
delete_task_status(task_id) self.delete_task_status(task_id)
os.remove(utils.join_path(download_dir, task_file)) os.remove(utils.join_path(download_dir, task_file))
await utils.send_json("delete_download_task", task_id) await utils.send_json("delete_download_task", task_id)
async def download_model(self, task_id: str, request):
async def download_model(task_id: str, request):
async def download_task(task_id: str): async def download_task(task_id: str):
async def report_progress(task_status: TaskStatus): async def report_progress(task_status: TaskStatus):
await utils.send_json("update_download_task", task_status.to_dict()) await utils.send_json("update_download_task", task_status.to_dict())
try: try:
# When starting a task from the queue, the task may not exist # When starting a task from the queue, the task may not exist
task_status = get_task_status(task_id) task_status = self.get_task_status(task_id)
except: except:
return return
@@ -263,7 +339,7 @@ async def download_model(task_id: str, request):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
progress_interval = 1.0 progress_interval = 1.0
await download_model_file( await self.download_model_file(
task_id=task_id, task_id=task_id,
headers=headers, headers=headers,
progress_callback=report_progress, progress_callback=report_progress,
@@ -277,9 +353,9 @@ async def download_model(task_id: str, request):
utils.print_error(str(e)) utils.print_error(str(e))
try: try:
status = download_thread_pool.submit(download_task, task_id) status = self.download_thread_pool.submit(download_task, task_id)
if status == "Waiting": if status == "Waiting":
task_status = get_task_status(task_id) task_status = self.get_task_status(task_id)
task_status.status = "waiting" task_status.status = "waiting"
await utils.send_json("update_download_task", task_status.to_dict()) await utils.send_json("update_download_task", task_status.to_dict())
except Exception as e: except Exception as e:
@@ -289,8 +365,8 @@ async def download_model(task_id: str, request):
task_status.error = None task_status.error = None
utils.print_error(str(e)) utils.print_error(str(e))
async def download_model_file( async def download_model_file(
self,
task_id: str, task_id: str,
headers: dict, headers: dict,
progress_callback: Callable[[TaskStatus], Awaitable[Any]], progress_callback: Callable[[TaskStatus], Awaitable[Any]],
@@ -331,8 +407,8 @@ async def download_model_file(
last_update_time = time.time() last_update_time = time.time()
last_downloaded_size = downloaded_size last_downloaded_size = downloaded_size
task_status = get_task_status(task_id) task_status = self.get_task_status(task_id)
task_content = get_task_content(task_id) task_content = self.get_task_content(task_id)
# Check download uri # Check download uri
model_url = task_content.downloadUrl model_url = task_content.downloadUrl
@@ -385,7 +461,7 @@ async def download_model_file(
total_size = float(response.headers.get("content-length", 0)) total_size = float(response.headers.get("content-length", 0))
task_content.sizeBytes = total_size task_content.sizeBytes = total_size
task_status.totalSize = total_size task_status.totalSize = total_size
set_task_content(task_id, task_content) self.set_task_content(task_id, task_content)
await utils.send_json("update_download_task", task_content.to_dict()) await utils.send_json("update_download_task", task_content.to_dict())
with open(download_tmp_file, "ab") as f: with open(download_tmp_file, "ab") as f:
@@ -406,86 +482,3 @@ async def download_model_file(
else: else:
task_status.status = "pause" task_status.status = "pause"
await utils.send_json("update_download_task", task_status.to_dict()) await utils.send_json("update_download_task", task_status.to_dict())
from aiohttp import web
class ModelDownload:
def add_routes(self, routes):
@routes.get("/model-manager/download/task")
async def scan_download_tasks(request):
"""
Read download task list.
"""
try:
result = await scan_model_download_task_list()
return web.json_response({"success": True, "data": result})
except Exception as e:
error_msg = f"Read download task list failed: {e}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
@routes.put("/model-manager/download/{task_id}")
async def resume_download_task(request):
"""
Toggle download task status.
"""
try:
task_id = request.match_info.get("task_id", None)
if task_id is None:
raise web.HTTPBadRequest(reason="Invalid task id")
json_data = await request.json()
status = json_data.get("status", None)
if status == "pause":
await pause_model_download_task(task_id)
elif status == "resume":
await download_model(task_id, request)
else:
raise web.HTTPBadRequest(reason="Invalid status")
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Resume download task failed: {str(e)}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
@routes.delete("/model-manager/download/{task_id}")
async def delete_model_download_task(request):
"""
Delete download task.
"""
task_id = request.match_info.get("task_id", None)
try:
await delete_model_download_task(task_id)
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Delete download task failed: {str(e)}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})
@routes.post("/model-manager/model")
async def create_model(request):
"""
Create a new model.
request body: x-www-form-urlencoded
- type: model type.
- pathIndex: index of the model folders.
- fullname: filename that relative to the model folder.
- previewFile: preview file.
- description: description.
- downloadPlatform: download platform.
- downloadUrl: download url.
- hash: a JSON string containing the hash value of the downloaded model.
"""
task_data = await request.post()
task_data = dict(task_data)
try:
task_id = await create_model_download_task(task_data, request)
return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e:
error_msg = f"Create model download task failed: {str(e)}"
utils.print_error(error_msg)
return web.json_response({"success": False, "error": error_msg})