fix: download module error (#141)
This commit is contained in:
241
py/download.py
241
py/download.py
@@ -3,8 +3,13 @@ import uuid
|
|||||||
import time
|
import time
|
||||||
import requests
|
import requests
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
from typing import Callable, Awaitable, Any, Literal, Union, Optional
|
from typing import Callable, Awaitable, Any, Literal, Union, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import thread
|
from . import thread
|
||||||
@@ -87,17 +92,95 @@ class TaskContent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
download_model_task_status: dict[str, TaskStatus] = {}
|
class ModelDownload:
|
||||||
download_thread_pool = thread.DownloadThreadPool()
|
def add_routes(self, routes):
|
||||||
|
|
||||||
|
@routes.get("/model-manager/download/task")
|
||||||
|
async def scan_download_tasks(request):
|
||||||
|
"""
|
||||||
|
Read download task list.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.scan_model_download_task_list()
|
||||||
|
return web.json_response({"success": True, "data": result})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Read download task list failed: {e}"
|
||||||
|
utils.print_error(error_msg)
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
|
@routes.put("/model-manager/download/{task_id}")
|
||||||
|
async def resume_download_task(request):
|
||||||
|
"""
|
||||||
|
Toggle download task status.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task_id = request.match_info.get("task_id", None)
|
||||||
|
if task_id is None:
|
||||||
|
raise web.HTTPBadRequest(reason="Invalid task id")
|
||||||
|
json_data = await request.json()
|
||||||
|
status = json_data.get("status", None)
|
||||||
|
if status == "pause":
|
||||||
|
await self.pause_model_download_task(task_id)
|
||||||
|
elif status == "resume":
|
||||||
|
await self.download_model(task_id, request)
|
||||||
|
else:
|
||||||
|
raise web.HTTPBadRequest(reason="Invalid status")
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Resume download task failed: {str(e)}"
|
||||||
|
utils.print_error(error_msg)
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
|
@routes.delete("/model-manager/download/{task_id}")
|
||||||
|
async def delete_model_download_task(request):
|
||||||
|
"""
|
||||||
|
Delete download task.
|
||||||
|
"""
|
||||||
|
task_id = request.match_info.get("task_id", None)
|
||||||
|
try:
|
||||||
|
await self.delete_model_download_task(task_id)
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Delete download task failed: {str(e)}"
|
||||||
|
utils.print_error(error_msg)
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
|
@routes.post("/model-manager/model")
|
||||||
|
async def create_model(request):
|
||||||
|
"""
|
||||||
|
Create a new model.
|
||||||
|
|
||||||
|
request body: x-www-form-urlencoded
|
||||||
|
- type: model type.
|
||||||
|
- pathIndex: index of the model folders.
|
||||||
|
- fullname: filename that relative to the model folder.
|
||||||
|
- previewFile: preview file.
|
||||||
|
- description: description.
|
||||||
|
- downloadPlatform: download platform.
|
||||||
|
- downloadUrl: download url.
|
||||||
|
- hash: a JSON string containing the hash value of the downloaded model.
|
||||||
|
"""
|
||||||
|
task_data = await request.post()
|
||||||
|
task_data = dict(task_data)
|
||||||
|
try:
|
||||||
|
task_id = await self.create_model_download_task(task_data, request)
|
||||||
|
return web.json_response({"success": True, "data": {"taskId": task_id}})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Create model download task failed: {str(e)}"
|
||||||
|
utils.print_error(error_msg)
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
|
download_model_task_status: dict[str, TaskStatus] = {}
|
||||||
|
|
||||||
|
download_thread_pool = thread.DownloadThreadPool()
|
||||||
|
|
||||||
|
def set_task_content(self, task_id: str, task_content: Union[TaskContent, dict]):
|
||||||
download_path = utils.get_download_path()
|
download_path = utils.get_download_path()
|
||||||
task_file_path = utils.join_path(download_path, f"{task_id}.task")
|
task_file_path = utils.join_path(download_path, f"{task_id}.task")
|
||||||
utils.save_dict_pickle_file(task_file_path, task_content)
|
utils.save_dict_pickle_file(task_file_path, task_content)
|
||||||
|
|
||||||
|
def get_task_content(self, task_id: str):
|
||||||
def get_task_content(task_id: str):
|
|
||||||
download_path = utils.get_download_path()
|
download_path = utils.get_download_path()
|
||||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||||
if not os.path.isfile(task_file):
|
if not os.path.isfile(task_file):
|
||||||
@@ -107,13 +190,12 @@ def get_task_content(task_id: str):
|
|||||||
return task_content
|
return task_content
|
||||||
return TaskContent(**task_content)
|
return TaskContent(**task_content)
|
||||||
|
|
||||||
|
def get_task_status(self, task_id: str):
|
||||||
def get_task_status(task_id: str):
|
task_status = self.download_model_task_status.get(task_id, None)
|
||||||
task_status = download_model_task_status.get(task_id, None)
|
|
||||||
|
|
||||||
if task_status is None:
|
if task_status is None:
|
||||||
download_path = utils.get_download_path()
|
download_path = utils.get_download_path()
|
||||||
task_content = get_task_content(task_id)
|
task_content = self.get_task_content(task_id)
|
||||||
download_file = utils.join_path(download_path, f"{task_id}.download")
|
download_file = utils.join_path(download_path, f"{task_id}.download")
|
||||||
download_size = 0
|
download_size = 0
|
||||||
if os.path.exists(download_file):
|
if os.path.exists(download_file):
|
||||||
@@ -131,16 +213,14 @@ def get_task_status(task_id: str):
|
|||||||
progress=download_size / total_size * 100 if total_size > 0 else 0,
|
progress=download_size / total_size * 100 if total_size > 0 else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
download_model_task_status[task_id] = task_status
|
self.download_model_task_status[task_id] = task_status
|
||||||
|
|
||||||
return task_status
|
return task_status
|
||||||
|
|
||||||
|
def delete_task_status(self, task_id: str):
|
||||||
|
self.download_model_task_status.pop(task_id, None)
|
||||||
|
|
||||||
def delete_task_status(task_id: str):
|
async def scan_model_download_task_list(self):
|
||||||
download_model_task_status.pop(task_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
async def scan_model_download_task_list():
|
|
||||||
"""
|
"""
|
||||||
Scan the download directory and send the task list to the client.
|
Scan the download directory and send the task list to the client.
|
||||||
"""
|
"""
|
||||||
@@ -155,13 +235,12 @@ async def scan_model_download_task_list():
|
|||||||
task_list: list[dict] = []
|
task_list: list[dict] = []
|
||||||
for task_file in task_files:
|
for task_file in task_files:
|
||||||
task_id = task_file.replace(".task", "")
|
task_id = task_file.replace(".task", "")
|
||||||
task_status = get_task_status(task_id)
|
task_status = self.get_task_status(task_id)
|
||||||
task_list.append(task_status.to_dict())
|
task_list.append(task_status.to_dict())
|
||||||
|
|
||||||
return task_list
|
return task_list
|
||||||
|
|
||||||
|
async def create_model_download_task(self, task_data: dict, request):
|
||||||
async def create_model_download_task(task_data: dict, request):
|
|
||||||
"""
|
"""
|
||||||
Creates a download task for the given data.
|
Creates a download task for the given data.
|
||||||
"""
|
"""
|
||||||
@@ -185,7 +264,7 @@ async def create_model_download_task(task_data: dict, request):
|
|||||||
try:
|
try:
|
||||||
preview_file = task_data.pop("previewFile", None)
|
preview_file = task_data.pop("previewFile", None)
|
||||||
utils.save_model_preview_image(task_path, preview_file, download_platform)
|
utils.save_model_preview_image(task_path, preview_file, download_platform)
|
||||||
set_task_content(task_id, task_data)
|
self.set_task_content(task_id, task_data)
|
||||||
task_status = TaskStatus(
|
task_status = TaskStatus(
|
||||||
taskId=task_id,
|
taskId=task_id,
|
||||||
type=model_type,
|
type=model_type,
|
||||||
@@ -194,23 +273,21 @@ async def create_model_download_task(task_data: dict, request):
|
|||||||
platform=download_platform,
|
platform=download_platform,
|
||||||
totalSize=float(task_data.get("sizeBytes", 0)),
|
totalSize=float(task_data.get("sizeBytes", 0)),
|
||||||
)
|
)
|
||||||
download_model_task_status[task_id] = task_status
|
self.download_model_task_status[task_id] = task_status
|
||||||
await utils.send_json("create_download_task", task_status.to_dict())
|
await utils.send_json("create_download_task", task_status.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await delete_model_download_task(task_id)
|
await self.delete_model_download_task(task_id)
|
||||||
raise RuntimeError(str(e)) from e
|
raise RuntimeError(str(e)) from e
|
||||||
|
|
||||||
await download_model(task_id, request)
|
await self.download_model(task_id, request)
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
|
async def pause_model_download_task(self, task_id: str):
|
||||||
async def pause_model_download_task(task_id: str):
|
task_status = self.get_task_status(task_id=task_id)
|
||||||
task_status = get_task_status(task_id=task_id)
|
|
||||||
task_status.status = "pause"
|
task_status.status = "pause"
|
||||||
|
|
||||||
|
async def delete_model_download_task(self, task_id: str):
|
||||||
async def delete_model_download_task(task_id: str):
|
task_status = self.get_task_status(task_id)
|
||||||
task_status = get_task_status(task_id)
|
|
||||||
is_running = task_status.status == "doing"
|
is_running = task_status.status == "doing"
|
||||||
task_status.status = "waiting"
|
task_status.status = "waiting"
|
||||||
await utils.send_json("delete_download_task", task_id)
|
await utils.send_json("delete_download_task", task_id)
|
||||||
@@ -225,20 +302,19 @@ async def delete_model_download_task(task_id: str):
|
|||||||
for task_file in task_file_list:
|
for task_file in task_file_list:
|
||||||
task_file_target = os.path.splitext(task_file)[0]
|
task_file_target = os.path.splitext(task_file)[0]
|
||||||
if task_file_target == task_id:
|
if task_file_target == task_id:
|
||||||
delete_task_status(task_id)
|
self.delete_task_status(task_id)
|
||||||
os.remove(utils.join_path(download_dir, task_file))
|
os.remove(utils.join_path(download_dir, task_file))
|
||||||
|
|
||||||
await utils.send_json("delete_download_task", task_id)
|
await utils.send_json("delete_download_task", task_id)
|
||||||
|
|
||||||
|
async def download_model(self, task_id: str, request):
|
||||||
async def download_model(task_id: str, request):
|
|
||||||
async def download_task(task_id: str):
|
async def download_task(task_id: str):
|
||||||
async def report_progress(task_status: TaskStatus):
|
async def report_progress(task_status: TaskStatus):
|
||||||
await utils.send_json("update_download_task", task_status.to_dict())
|
await utils.send_json("update_download_task", task_status.to_dict())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# When starting a task from the queue, the task may not exist
|
# When starting a task from the queue, the task may not exist
|
||||||
task_status = get_task_status(task_id)
|
task_status = self.get_task_status(task_id)
|
||||||
except:
|
except:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -263,7 +339,7 @@ async def download_model(task_id: str, request):
|
|||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
progress_interval = 1.0
|
progress_interval = 1.0
|
||||||
await download_model_file(
|
await self.download_model_file(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
progress_callback=report_progress,
|
progress_callback=report_progress,
|
||||||
@@ -277,9 +353,9 @@ async def download_model(task_id: str, request):
|
|||||||
utils.print_error(str(e))
|
utils.print_error(str(e))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
status = download_thread_pool.submit(download_task, task_id)
|
status = self.download_thread_pool.submit(download_task, task_id)
|
||||||
if status == "Waiting":
|
if status == "Waiting":
|
||||||
task_status = get_task_status(task_id)
|
task_status = self.get_task_status(task_id)
|
||||||
task_status.status = "waiting"
|
task_status.status = "waiting"
|
||||||
await utils.send_json("update_download_task", task_status.to_dict())
|
await utils.send_json("update_download_task", task_status.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -289,13 +365,13 @@ async def download_model(task_id: str, request):
|
|||||||
task_status.error = None
|
task_status.error = None
|
||||||
utils.print_error(str(e))
|
utils.print_error(str(e))
|
||||||
|
|
||||||
|
async def download_model_file(
|
||||||
async def download_model_file(
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
||||||
interval: float = 1.0,
|
interval: float = 1.0,
|
||||||
):
|
):
|
||||||
|
|
||||||
async def download_complete():
|
async def download_complete():
|
||||||
"""
|
"""
|
||||||
@@ -331,8 +407,8 @@ async def download_model_file(
|
|||||||
last_update_time = time.time()
|
last_update_time = time.time()
|
||||||
last_downloaded_size = downloaded_size
|
last_downloaded_size = downloaded_size
|
||||||
|
|
||||||
task_status = get_task_status(task_id)
|
task_status = self.get_task_status(task_id)
|
||||||
task_content = get_task_content(task_id)
|
task_content = self.get_task_content(task_id)
|
||||||
|
|
||||||
# Check download uri
|
# Check download uri
|
||||||
model_url = task_content.downloadUrl
|
model_url = task_content.downloadUrl
|
||||||
@@ -385,7 +461,7 @@ async def download_model_file(
|
|||||||
total_size = float(response.headers.get("content-length", 0))
|
total_size = float(response.headers.get("content-length", 0))
|
||||||
task_content.sizeBytes = total_size
|
task_content.sizeBytes = total_size
|
||||||
task_status.totalSize = total_size
|
task_status.totalSize = total_size
|
||||||
set_task_content(task_id, task_content)
|
self.set_task_content(task_id, task_content)
|
||||||
await utils.send_json("update_download_task", task_content.to_dict())
|
await utils.send_json("update_download_task", task_content.to_dict())
|
||||||
|
|
||||||
with open(download_tmp_file, "ab") as f:
|
with open(download_tmp_file, "ab") as f:
|
||||||
@@ -406,86 +482,3 @@ async def download_model_file(
|
|||||||
else:
|
else:
|
||||||
task_status.status = "pause"
|
task_status.status = "pause"
|
||||||
await utils.send_json("update_download_task", task_status.to_dict())
|
await utils.send_json("update_download_task", task_status.to_dict())
|
||||||
|
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDownload:
|
|
||||||
def add_routes(self, routes):
|
|
||||||
|
|
||||||
@routes.get("/model-manager/download/task")
|
|
||||||
async def scan_download_tasks(request):
|
|
||||||
"""
|
|
||||||
Read download task list.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await scan_model_download_task_list()
|
|
||||||
return web.json_response({"success": True, "data": result})
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Read download task list failed: {e}"
|
|
||||||
utils.print_error(error_msg)
|
|
||||||
return web.json_response({"success": False, "error": error_msg})
|
|
||||||
|
|
||||||
@routes.put("/model-manager/download/{task_id}")
|
|
||||||
async def resume_download_task(request):
|
|
||||||
"""
|
|
||||||
Toggle download task status.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
task_id = request.match_info.get("task_id", None)
|
|
||||||
if task_id is None:
|
|
||||||
raise web.HTTPBadRequest(reason="Invalid task id")
|
|
||||||
json_data = await request.json()
|
|
||||||
status = json_data.get("status", None)
|
|
||||||
if status == "pause":
|
|
||||||
await pause_model_download_task(task_id)
|
|
||||||
elif status == "resume":
|
|
||||||
await download_model(task_id, request)
|
|
||||||
else:
|
|
||||||
raise web.HTTPBadRequest(reason="Invalid status")
|
|
||||||
|
|
||||||
return web.json_response({"success": True})
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Resume download task failed: {str(e)}"
|
|
||||||
utils.print_error(error_msg)
|
|
||||||
return web.json_response({"success": False, "error": error_msg})
|
|
||||||
|
|
||||||
@routes.delete("/model-manager/download/{task_id}")
|
|
||||||
async def delete_model_download_task(request):
|
|
||||||
"""
|
|
||||||
Delete download task.
|
|
||||||
"""
|
|
||||||
task_id = request.match_info.get("task_id", None)
|
|
||||||
try:
|
|
||||||
await delete_model_download_task(task_id)
|
|
||||||
return web.json_response({"success": True})
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Delete download task failed: {str(e)}"
|
|
||||||
utils.print_error(error_msg)
|
|
||||||
return web.json_response({"success": False, "error": error_msg})
|
|
||||||
|
|
||||||
@routes.post("/model-manager/model")
|
|
||||||
async def create_model(request):
|
|
||||||
"""
|
|
||||||
Create a new model.
|
|
||||||
|
|
||||||
request body: x-www-form-urlencoded
|
|
||||||
- type: model type.
|
|
||||||
- pathIndex: index of the model folders.
|
|
||||||
- fullname: filename that relative to the model folder.
|
|
||||||
- previewFile: preview file.
|
|
||||||
- description: description.
|
|
||||||
- downloadPlatform: download platform.
|
|
||||||
- downloadUrl: download url.
|
|
||||||
- hash: a JSON string containing the hash value of the downloaded model.
|
|
||||||
"""
|
|
||||||
task_data = await request.post()
|
|
||||||
task_data = dict(task_data)
|
|
||||||
try:
|
|
||||||
task_id = await create_model_download_task(task_data, request)
|
|
||||||
return web.json_response({"success": True, "data": {"taskId": task_id}})
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Create model download task failed: {str(e)}"
|
|
||||||
utils.print_error(error_msg)
|
|
||||||
return web.json_response({"success": False, "error": error_msg})
|
|
||||||
|
|||||||
Reference in New Issue
Block a user