diff --git a/__init__.py b/__init__.py index 6b1c0a6..57c96fb 100644 --- a/__init__.py +++ b/__init__.py @@ -21,13 +21,61 @@ from .py import services routes = config.routes -@routes.get("/model-manager/ws") -async def socket_handler(request): +@routes.get("/model-manager/download/task") +async def scan_download_tasks(request): """ - Handle websocket connection. + Read download task list. """ - ws = await services.connect_websocket(request) - return ws + try: + result = await services.scan_model_download_task_list() + return web.json_response({"success": True, "data": result}) + except Exception as e: + error_msg = f"Read download task list failed: {e}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) + + +@routes.put("/model-manager/download/{task_id}") +async def resume_download_task(request): + """ + Toggle download task status. + """ + try: + task_id = request.match_info.get("task_id", None) + if task_id is None: + raise web.HTTPBadRequest(reason="Invalid task id") + json_data = await request.json() + status = json_data.get("status", None) + if status == "pause": + await services.pause_model_download_task(task_id) + elif status == "resume": + await services.resume_model_download_task(task_id, request) + else: + raise web.HTTPBadRequest(reason="Invalid status") + + return web.json_response({"success": True}) + except Exception as e: + error_msg = f"Resume download task failed: {str(e)}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) + + +@routes.delete("/model-manager/download/{task_id}") +async def delete_model_download_task(request): + """ + Delete download task. + """ + task_id = request.match_info.get("task_id", None) + try: + await services.delete_model_download_task(task_id) + return web.json_response({"success": True}) + except Exception as e: + error_msg = f"Delete download task failed: {str(e)}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) @routes.get("/model-manager/base-folders") @@ -56,7 +104,7 @@ async def create_model(request): """ post = await request.post() try: - task_id = await services.create_model_download_task(post) + task_id = await services.create_model_download_task(post, request) return web.json_response({"success": True, "data": {"taskId": task_id}}) except Exception as e: error_msg = f"Create model download task failed: {str(e)}" diff --git a/py/config.py b/py/config.py index 8efbdee..f4d11a5 100644 --- a/py/config.py +++ b/py/config.py @@ -19,15 +19,3 @@ from server import PromptServer serverInstance = PromptServer.instance routes = serverInstance.routes - - -class FakeRequest: - def __init__(self): - self.headers = {} - - -class CustomException(BaseException): - def __init__(self, type: str, message: str = None) -> None: - self.type = type - self.message = message - super().__init__(message) diff --git a/py/download.py b/py/download.py index 9779951..95fafdd 100644 --- a/py/download.py +++ b/py/download.py @@ -9,7 +9,6 @@ from typing import Callable, Awaitable, Any, Literal, Union, Optional from dataclasses import dataclass from . import config from . import utils -from . import socket from . import thread @@ -93,33 +92,28 @@ def delete_task_status(task_id: str): download_model_task_status.pop(task_id, None) -async def scan_model_download_task_list(sid: str): +async def scan_model_download_task_list(): """ Scan the download directory and send the task list to the client. """ - try: - download_dir = utils.get_download_path() - task_files = utils.search_files(download_dir) - task_files = folder_paths.filter_files_extensions(task_files, [".task"]) - task_files = sorted( - task_files, - key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime, - reverse=True, - ) - task_list: list[dict] = [] - for task_file in task_files: - task_id = task_file.replace(".task", "") - task_status = get_task_status(task_id) - task_list.append(task_status) + download_dir = utils.get_download_path() + task_files = utils.search_files(download_dir) + task_files = folder_paths.filter_files_extensions(task_files, [".task"]) + task_files = sorted( + task_files, + key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime, + reverse=True, + ) + task_list: list[dict] = [] + for task_file in task_files: + task_id = task_file.replace(".task", "") + task_status = get_task_status(task_id) + task_list.append(task_status) - await socket.send_json("downloadTaskList", task_list, sid) - except Exception as e: - error_msg = f"Refresh task list failed: {e}" - await socket.send_json("error", error_msg, sid) - logging.error(error_msg) + return utils.unpack_dataclass(task_list) -async def create_model_download_task(post: dict): +async def create_model_download_task(post: dict, request): """ Creates a download task for the given post. """ @@ -152,12 +146,12 @@ async def create_model_download_task(post: dict): totalSize=float(post.get("sizeBytes", 0)), ) download_model_task_status[task_id] = task_status - await socket.send_json("createDownloadTask", task_status) + await utils.send_json("create_download_task", task_status) except Exception as e: await delete_model_download_task(task_id) raise RuntimeError(str(e)) from e - await download_model(task_id) + await download_model(task_id, request) return task_id @@ -170,7 +164,7 @@ async def delete_model_download_task(task_id: str): task_status = get_task_status(task_id) is_running = task_status.status == "doing" task_status.status = "waiting" - await socket.send_json("deleteDownloadTask", task_id) + await utils.send_json("delete_download_task", task_id) # Pause the task if is_running: @@ -185,13 +179,13 @@ async def delete_model_download_task(task_id: str): delete_task_status(task_id) os.remove(utils.join_path(download_dir, task_file)) - await socket.send_json("deleteDownloadTask", task_id) + await utils.send_json("delete_download_task", task_id) -async def download_model(task_id: str): +async def download_model(task_id: str, request): async def download_task(task_id: str): async def report_progress(task_status: TaskStatus): - await socket.send_json("updateDownloadTask", task_status) + await utils.send_json("update_download_task", task_status) try: # When starting a task from the queue, the task may not exist @@ -201,7 +195,7 @@ async def download_model(task_id: str): # Update task status task_status.status = "doing" - await socket.send_json("updateDownloadTask", task_status) + await utils.send_json("update_download_task", task_status) try: @@ -210,12 +204,12 @@ async def download_model(task_id: str): download_platform = task_status.platform if download_platform == "civitai": - api_key = utils.get_setting_value("api_key.civitai") + api_key = utils.get_setting_value(request, "api_key.civitai") if api_key: headers["Authorization"] = f"Bearer {api_key}" elif download_platform == "huggingface": - api_key = utils.get_setting_value("api_key.huggingface") + api_key = utils.get_setting_value(request, "api_key.huggingface") if api_key: headers["Authorization"] = f"Bearer {api_key}" @@ -229,7 +223,7 @@ async def download_model(task_id: str): except Exception as e: task_status.status = "pause" task_status.error = str(e) - await socket.send_json("updateDownloadTask", task_status) + await utils.send_json("update_download_task", task_status) task_status.error = None logging.error(str(e)) @@ -238,11 +232,11 @@ async def download_model(task_id: str): if status == "Waiting": task_status = get_task_status(task_id) task_status.status = "waiting" - await socket.send_json("updateDownloadTask", task_status) + await utils.send_json("update_download_task", task_status) except Exception as e: task_status.status = "pause" task_status.error = str(e) - await socket.send_json("updateDownloadTask", task_status) + await utils.send_json("update_download_task", task_status) task_status.error = None logging.error(traceback.format_exc()) @@ -275,7 +269,7 @@ async def download_model_file( time.sleep(1) task_file = utils.join_path(download_path, f"{task_id}.task") os.remove(task_file) - await socket.send_json("completeDownloadTask", task_id) + await utils.send_json("complete_download_task", task_id) async def update_progress(): nonlocal last_update_time @@ -347,7 +341,7 @@ async def download_model_file( task_content.sizeBytes = total_size task_status.totalSize = total_size set_task_content(task_id, task_content) - await socket.send_json("updateDownloadTask", task_content) + await utils.send_json("update_download_task", task_content) with open(download_tmp_file, "ab") as f: for chunk in response.iter_content(chunk_size=8192): @@ -366,4 +360,4 @@ async def download_model_file( await download_complete() else: task_status.status = "pause" - await socket.send_json("updateDownloadTask", task_status) + await utils.send_json("update_download_task", task_status) diff --git a/py/services.py b/py/services.py index 9518a3f..a94c0c7 100644 --- a/py/services.py +++ b/py/services.py @@ -1,37 +1,13 @@ import os -import logging -import traceback + import folder_paths -from typing import Any from multidict import MultiDictProxy from . import config from . import utils -from . import socket from . import download -async def connect_websocket(request): - async def message_handler(event_type: str, detail: Any, sid: str): - try: - if event_type == "downloadTaskList": - await download.scan_model_download_task_list(sid=sid) - - if event_type == "resumeDownloadTask": - await download.download_model(task_id=detail) - - if event_type == "pauseDownloadTask": - await download.pause_model_download_task(task_id=detail) - - if event_type == "deleteDownloadTask": - await download.delete_model_download_task(task_id=detail) - except Exception: - logging.error(traceback.format_exc()) - - ws = await socket.create_websocket_handler(request, handler=message_handler) - return ws - - def scan_models(): result = [] model_base_paths = config.model_base_paths @@ -135,6 +111,22 @@ def remove_model(model_path: str): os.remove(utils.join_path(model_dirname, description)) -async def create_model_download_task(post): +async def create_model_download_task(post, request): dict_post = dict(post) - return await download.create_model_download_task(dict_post) + return await download.create_model_download_task(dict_post, request) + + +async def scan_model_download_task_list(): + return await download.scan_model_download_task_list() + + +async def pause_model_download_task(task_id): + return await download.pause_model_download_task(task_id) + + +async def resume_model_download_task(task_id, request): + return await download.download_model(task_id, request) + + +async def delete_model_download_task(task_id): + return await download.delete_model_download_task(task_id) diff --git a/py/socket.py b/py/socket.py deleted file mode 100644 index 13a39f9..0000000 --- a/py/socket.py +++ /dev/null @@ -1,63 +0,0 @@ -import aiohttp -import logging -import uuid -import json -from aiohttp import web -from typing import Any, Callable, Awaitable -from . import utils - - -__sockets: dict[str, web.WebSocketResponse] = {} - - -async def create_websocket_handler( - request, handler: Callable[[str, Any, str], Awaitable[Any]] -): - ws = web.WebSocketResponse() - await ws.prepare(request) - sid = request.rel_url.query.get("clientId", "") - if sid: - # Reusing existing session, remove old - __sockets.pop(sid, None) - else: - sid = uuid.uuid4().hex - - __sockets[sid] = ws - - try: - async for msg in ws: - if msg.type == aiohttp.WSMsgType.ERROR: - logging.warning( - "ws connection closed with exception %s" % ws.exception() - ) - if msg.type == aiohttp.WSMsgType.TEXT: - data = json.loads(msg.data) - await handler(data.get("type"), data.get("detail"), sid) - finally: - __sockets.pop(sid, None) - return ws - - -async def send_json(event: str, data: Any, sid: str = None): - detail = utils.unpack_dataclass(data) - message = {"type": event, "data": detail} - - if sid is None: - socket_list = list(__sockets.values()) - for ws in socket_list: - await __send_socket_catch_exception(ws.send_json, message) - elif sid in __sockets: - await __send_socket_catch_exception(__sockets[sid].send_json, message) - - -async def __send_socket_catch_exception(function, message): - try: - await function(message) - except ( - aiohttp.ClientError, - aiohttp.ClientPayloadError, - ConnectionResetError, - BrokenPipeError, - ConnectionError, - ) as err: - logging.warning("send error: {}".format(err)) diff --git a/py/thread.py b/py/thread.py index 82689f7..e40798e 100644 --- a/py/thread.py +++ b/py/thread.py @@ -2,7 +2,6 @@ import asyncio import threading import queue import logging -from . import utils class DownloadThreadPool: @@ -13,14 +12,7 @@ class DownloadThreadPool: self._lock = threading.Lock() default_max_workers = 5 - max_workers: int = utils.get_setting_value( - "download.max_task_count", default_max_workers - ) - - if max_workers <= 0: - max_workers = default_max_workers - utils.set_setting_value("download.max_task_count", max_workers) - + max_workers: int = default_max_workers self.max_worker = max_workers def submit(self, task, task_id): diff --git a/py/utils.py b/py/utils.py index 1aa9dff..e433547 100644 --- a/py/utils.py +++ b/py/utils.py @@ -334,18 +334,16 @@ def resolve_setting_key(key: str) -> str: return setting_id -def set_setting_value(key: str, value: Any): +def set_setting_value(request: web.Request, key: str, value: Any): setting_id = resolve_setting_key(key) - fake_request = config.FakeRequest() - settings = config.serverInstance.user_manager.settings.get_settings(fake_request) + settings = config.serverInstance.user_manager.settings.get_settings(request) settings[setting_id] = value - config.serverInstance.user_manager.settings.save_settings(fake_request, settings) + config.serverInstance.user_manager.settings.save_settings(request, settings) -def get_setting_value(key: str, default: Any = None) -> Any: +def get_setting_value(request: web.Request, key: str, default: Any = None) -> Any: setting_id = resolve_setting_key(key) - fake_request = config.FakeRequest() - settings = config.serverInstance.user_manager.settings.get_settings(fake_request) + settings = config.serverInstance.user_manager.settings.get_settings(request) return settings.get(setting_id, default) @@ -361,3 +359,8 @@ def unpack_dataclass(data: Any): return asdict(data) else: return data + + +async def send_json(event: str, data: Any, sid: str = None): + detail = unpack_dataclass(data) + await config.serverInstance.send_json(event, detail, sid) diff --git a/src/hooks/download.ts b/src/hooks/download.ts index fb1f865..8d41fb4 100644 --- a/src/hooks/download.ts +++ b/src/hooks/download.ts @@ -1,8 +1,9 @@ import { useLoading } from 'hooks/loading' import { MarkdownTool, useMarkdown } from 'hooks/markdown' -import { socket } from 'hooks/socket' +import { request } from 'hooks/request' import { defineStore } from 'hooks/store' import { useToast } from 'hooks/toast' +import { api } from 'scripts/comfyAPI' import { bytesToSize } from 'utils/common' import { onBeforeMount, onMounted, ref, watch } from 'vue' import { useI18n } from 'vue-i18n' @@ -13,10 +14,6 @@ export const useDownload = defineStore('download', (store) => { const taskList = ref([]) - const refresh = () => { - socket.send('downloadTaskList', null) - } - const createTaskItem = (item: DownloadTaskOptions) => { const { downloadedSize, totalSize, bps, ...rest } = item @@ -26,10 +23,20 @@ export const useDownload = defineStore('download', (store) => { downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`, downloadSpeed: `${bytesToSize(bps)}/s`, pauseTask() { - socket.send('pauseDownloadTask', item.taskId) + request(`/download/${item.taskId}`, { + method: 'PUT', + body: JSON.stringify({ + status: 'pause', + }), + }) }, resumeTask: () => { - socket.send('resumeDownloadTask', item.taskId) + request(`/download/${item.taskId}`, { + method: 'PUT', + body: JSON.stringify({ + status: 'resume', + }), + }) }, deleteTask: () => { confirm.require({ @@ -46,7 +53,9 @@ export const useDownload = defineStore('download', (store) => { severity: 'danger', }, accept: () => { - socket.send('deleteDownloadTask', item.taskId) + request(`/download/${item.taskId}`, { + method: 'DELETE', + }) }, reject: () => {}, }) @@ -56,12 +65,28 @@ export const useDownload = defineStore('download', (store) => { return task } + const refresh = async () => { + return request('/download/task') + .then((resData: DownloadTaskOptions[]) => { + taskList.value = resData.map((item) => createTaskItem(item)) + return taskList.value + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to refresh download task list', + life: 15000, + }) + }) + } + onBeforeMount(() => { - socket.addEventListener('reconnected', () => { + api.addEventListener('reconnected', () => { refresh() }) - socket.addEventListener('downloadTaskList', (event) => { + api.addEventListener('fetch_download_task_list', (event) => { const data = event.detail as DownloadTaskOptions[] taskList.value = data.map((item) => { @@ -69,12 +94,12 @@ export const useDownload = defineStore('download', (store) => { }) }) - socket.addEventListener('createDownloadTask', (event) => { + api.addEventListener('create_download_task', (event) => { const item = event.detail as DownloadTaskOptions taskList.value.unshift(createTaskItem(item)) }) - socket.addEventListener('updateDownloadTask', (event) => { + api.addEventListener('update_download_task', (event) => { const item = event.detail as DownloadTaskOptions for (const task of taskList.value) { @@ -93,12 +118,12 @@ export const useDownload = defineStore('download', (store) => { } }) - socket.addEventListener('deleteDownloadTask', (event) => { + api.addEventListener('delete_download_task', (event) => { const taskId = event.detail as string taskList.value = taskList.value.filter((item) => item.taskId !== taskId) }) - socket.addEventListener('completeDownloadTask', (event) => { + api.addEventListener('complete_download_task', (event) => { const taskId = event.detail as string const task = taskList.value.find((item) => item.taskId === taskId) taskList.value = taskList.value.filter((item) => item.taskId !== taskId) diff --git a/src/hooks/socket.ts b/src/hooks/socket.ts deleted file mode 100644 index 582a43a..0000000 --- a/src/hooks/socket.ts +++ /dev/null @@ -1,82 +0,0 @@ -import { globalToast } from 'hooks/toast' -import { readonly } from 'vue' - -class WebSocketEvent extends EventTarget { - private socket: WebSocket | null - - constructor() { - super() - this.createSocket() - } - - private createSocket(isReconnect?: boolean) { - const api_host = location.host - const api_base = location.pathname.split('/').slice(0, -1).join('/') - - let opened = false - let existingSession = window.name - if (existingSession) { - existingSession = '?clientId=' + existingSession - } - - this.socket = readonly( - new WebSocket( - `ws${window.location.protocol === 'https:' ? 's' : ''}://${api_host}${api_base}/model-manager/ws${existingSession}`, - ), - ) - - this.socket.addEventListener('open', () => { - opened = true - if (isReconnect) { - this.dispatchEvent(new CustomEvent('reconnected')) - } - }) - - this.socket.addEventListener('error', () => { - if (this.socket) this.socket.close() - }) - - this.socket.addEventListener('close', (event) => { - setTimeout(() => { - this.socket = null - this.createSocket(true) - }, 300) - if (opened) { - this.dispatchEvent(new CustomEvent('status', { detail: null })) - this.dispatchEvent(new CustomEvent('reconnecting')) - } - }) - - this.socket.addEventListener('message', (event) => { - try { - const msg = JSON.parse(event.data) - if (msg.type === 'error') { - globalToast.value?.add({ - severity: 'error', - summary: 'Error', - detail: msg.data, - life: 15000, - }) - } else { - this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })) - } - } catch (error) { - console.error(error) - } - }) - } - - addEventListener = ( - type: string, - callback: CustomEventListener | null, - options?: AddEventListenerOptions | boolean, - ) => { - super.addEventListener(type, callback, options) - } - - send(type: string, data: any) { - this.socket?.send(JSON.stringify({ type, detail: data })) - } -} - -export const socket = new WebSocketEvent()