Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e36af38375 | ||
|
|
d4922f59d3 | ||
|
|
f2e17744ae | ||
|
|
3b25d3e347 | ||
|
|
3a0676b29f | ||
|
|
a1e5761dbc | ||
|
|
ae518b541a |
60
__init__.py
60
__init__.py
@@ -21,13 +21,61 @@ from .py import services
|
|||||||
routes = config.routes
|
routes = config.routes
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/model-manager/ws")
|
@routes.get("/model-manager/download/task")
|
||||||
async def socket_handler(request):
|
async def scan_download_tasks(request):
|
||||||
"""
|
"""
|
||||||
Handle websocket connection.
|
Read download task list.
|
||||||
"""
|
"""
|
||||||
ws = await services.connect_websocket(request)
|
try:
|
||||||
return ws
|
result = await services.scan_model_download_task_list()
|
||||||
|
return web.json_response({"success": True, "data": result})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Read download task list failed: {e}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
logging.debug(traceback.format_exc())
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
|
|
||||||
|
@routes.put("/model-manager/download/{task_id}")
|
||||||
|
async def resume_download_task(request):
|
||||||
|
"""
|
||||||
|
Toggle download task status.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
task_id = request.match_info.get("task_id", None)
|
||||||
|
if task_id is None:
|
||||||
|
raise web.HTTPBadRequest(reason="Invalid task id")
|
||||||
|
json_data = await request.json()
|
||||||
|
status = json_data.get("status", None)
|
||||||
|
if status == "pause":
|
||||||
|
await services.pause_model_download_task(task_id)
|
||||||
|
elif status == "resume":
|
||||||
|
await services.resume_model_download_task(task_id, request)
|
||||||
|
else:
|
||||||
|
raise web.HTTPBadRequest(reason="Invalid status")
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Resume download task failed: {str(e)}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
logging.debug(traceback.format_exc())
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
|
|
||||||
|
@routes.delete("/model-manager/download/{task_id}")
|
||||||
|
async def delete_model_download_task(request):
|
||||||
|
"""
|
||||||
|
Delete download task.
|
||||||
|
"""
|
||||||
|
task_id = request.match_info.get("task_id", None)
|
||||||
|
try:
|
||||||
|
await services.delete_model_download_task(task_id)
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Delete download task failed: {str(e)}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
logging.debug(traceback.format_exc())
|
||||||
|
return web.json_response({"success": False, "error": error_msg})
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/model-manager/base-folders")
|
@routes.get("/model-manager/base-folders")
|
||||||
@@ -56,7 +104,7 @@ async def create_model(request):
|
|||||||
"""
|
"""
|
||||||
post = await request.post()
|
post = await request.post()
|
||||||
try:
|
try:
|
||||||
task_id = await services.create_model_download_task(post)
|
task_id = await services.create_model_download_task(post, request)
|
||||||
return web.json_response({"success": True, "data": {"taskId": task_id}})
|
return web.json_response({"success": True, "data": {"taskId": task_id}})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Create model download task failed: {str(e)}"
|
error_msg = f"Create model download task failed: {str(e)}"
|
||||||
|
|||||||
12
py/config.py
12
py/config.py
@@ -19,15 +19,3 @@ from server import PromptServer
|
|||||||
|
|
||||||
serverInstance = PromptServer.instance
|
serverInstance = PromptServer.instance
|
||||||
routes = serverInstance.routes
|
routes = serverInstance.routes
|
||||||
|
|
||||||
|
|
||||||
class FakeRequest:
|
|
||||||
def __init__(self):
|
|
||||||
self.headers = {}
|
|
||||||
|
|
||||||
|
|
||||||
class CustomException(BaseException):
|
|
||||||
def __init__(self, type: str, message: str = None) -> None:
|
|
||||||
self.type = type
|
|
||||||
self.message = message
|
|
||||||
super().__init__(message)
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from typing import Callable, Awaitable, Any, Literal, Union, Optional
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from . import config
|
from . import config
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import socket
|
|
||||||
from . import thread
|
from . import thread
|
||||||
|
|
||||||
|
|
||||||
@@ -93,33 +92,28 @@ def delete_task_status(task_id: str):
|
|||||||
download_model_task_status.pop(task_id, None)
|
download_model_task_status.pop(task_id, None)
|
||||||
|
|
||||||
|
|
||||||
async def scan_model_download_task_list(sid: str):
|
async def scan_model_download_task_list():
|
||||||
"""
|
"""
|
||||||
Scan the download directory and send the task list to the client.
|
Scan the download directory and send the task list to the client.
|
||||||
"""
|
"""
|
||||||
try:
|
download_dir = utils.get_download_path()
|
||||||
download_dir = utils.get_download_path()
|
task_files = utils.search_files(download_dir)
|
||||||
task_files = utils.search_files(download_dir)
|
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
||||||
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
task_files = sorted(
|
||||||
task_files = sorted(
|
task_files,
|
||||||
task_files,
|
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
||||||
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
reverse=True,
|
||||||
reverse=True,
|
)
|
||||||
)
|
task_list: list[dict] = []
|
||||||
task_list: list[dict] = []
|
for task_file in task_files:
|
||||||
for task_file in task_files:
|
task_id = task_file.replace(".task", "")
|
||||||
task_id = task_file.replace(".task", "")
|
task_status = get_task_status(task_id)
|
||||||
task_status = get_task_status(task_id)
|
task_list.append(task_status)
|
||||||
task_list.append(task_status)
|
|
||||||
|
|
||||||
await socket.send_json("downloadTaskList", task_list, sid)
|
return utils.unpack_dataclass(task_list)
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Refresh task list failed: {e}"
|
|
||||||
await socket.send_json("error", error_msg, sid)
|
|
||||||
logging.error(error_msg)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_model_download_task(post: dict):
|
async def create_model_download_task(post: dict, request):
|
||||||
"""
|
"""
|
||||||
Creates a download task for the given post.
|
Creates a download task for the given post.
|
||||||
"""
|
"""
|
||||||
@@ -152,12 +146,12 @@ async def create_model_download_task(post: dict):
|
|||||||
totalSize=float(post.get("sizeBytes", 0)),
|
totalSize=float(post.get("sizeBytes", 0)),
|
||||||
)
|
)
|
||||||
download_model_task_status[task_id] = task_status
|
download_model_task_status[task_id] = task_status
|
||||||
await socket.send_json("createDownloadTask", task_status)
|
await utils.send_json("create_download_task", task_status)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await delete_model_download_task(task_id)
|
await delete_model_download_task(task_id)
|
||||||
raise RuntimeError(str(e)) from e
|
raise RuntimeError(str(e)) from e
|
||||||
|
|
||||||
await download_model(task_id)
|
await download_model(task_id, request)
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
|
|
||||||
@@ -170,7 +164,7 @@ async def delete_model_download_task(task_id: str):
|
|||||||
task_status = get_task_status(task_id)
|
task_status = get_task_status(task_id)
|
||||||
is_running = task_status.status == "doing"
|
is_running = task_status.status == "doing"
|
||||||
task_status.status = "waiting"
|
task_status.status = "waiting"
|
||||||
await socket.send_json("deleteDownloadTask", task_id)
|
await utils.send_json("delete_download_task", task_id)
|
||||||
|
|
||||||
# Pause the task
|
# Pause the task
|
||||||
if is_running:
|
if is_running:
|
||||||
@@ -185,13 +179,13 @@ async def delete_model_download_task(task_id: str):
|
|||||||
delete_task_status(task_id)
|
delete_task_status(task_id)
|
||||||
os.remove(utils.join_path(download_dir, task_file))
|
os.remove(utils.join_path(download_dir, task_file))
|
||||||
|
|
||||||
await socket.send_json("deleteDownloadTask", task_id)
|
await utils.send_json("delete_download_task", task_id)
|
||||||
|
|
||||||
|
|
||||||
async def download_model(task_id: str):
|
async def download_model(task_id: str, request):
|
||||||
async def download_task(task_id: str):
|
async def download_task(task_id: str):
|
||||||
async def report_progress(task_status: TaskStatus):
|
async def report_progress(task_status: TaskStatus):
|
||||||
await socket.send_json("updateDownloadTask", task_status)
|
await utils.send_json("update_download_task", task_status)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# When starting a task from the queue, the task may not exist
|
# When starting a task from the queue, the task may not exist
|
||||||
@@ -201,7 +195,7 @@ async def download_model(task_id: str):
|
|||||||
|
|
||||||
# Update task status
|
# Update task status
|
||||||
task_status.status = "doing"
|
task_status.status = "doing"
|
||||||
await socket.send_json("updateDownloadTask", task_status)
|
await utils.send_json("update_download_task", task_status)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
@@ -210,12 +204,12 @@ async def download_model(task_id: str):
|
|||||||
|
|
||||||
download_platform = task_status.platform
|
download_platform = task_status.platform
|
||||||
if download_platform == "civitai":
|
if download_platform == "civitai":
|
||||||
api_key = utils.get_setting_value("api_key.civitai")
|
api_key = utils.get_setting_value(request, "api_key.civitai")
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
elif download_platform == "huggingface":
|
elif download_platform == "huggingface":
|
||||||
api_key = utils.get_setting_value("api_key.huggingface")
|
api_key = utils.get_setting_value(request, "api_key.huggingface")
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
@@ -229,7 +223,7 @@ async def download_model(task_id: str):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
task_status.status = "pause"
|
task_status.status = "pause"
|
||||||
task_status.error = str(e)
|
task_status.error = str(e)
|
||||||
await socket.send_json("updateDownloadTask", task_status)
|
await utils.send_json("update_download_task", task_status)
|
||||||
task_status.error = None
|
task_status.error = None
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
|
|
||||||
@@ -238,11 +232,11 @@ async def download_model(task_id: str):
|
|||||||
if status == "Waiting":
|
if status == "Waiting":
|
||||||
task_status = get_task_status(task_id)
|
task_status = get_task_status(task_id)
|
||||||
task_status.status = "waiting"
|
task_status.status = "waiting"
|
||||||
await socket.send_json("updateDownloadTask", task_status)
|
await utils.send_json("update_download_task", task_status)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
task_status.status = "pause"
|
task_status.status = "pause"
|
||||||
task_status.error = str(e)
|
task_status.error = str(e)
|
||||||
await socket.send_json("updateDownloadTask", task_status)
|
await utils.send_json("update_download_task", task_status)
|
||||||
task_status.error = None
|
task_status.error = None
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
|
|
||||||
@@ -275,7 +269,7 @@ async def download_model_file(
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||||
os.remove(task_file)
|
os.remove(task_file)
|
||||||
await socket.send_json("completeDownloadTask", task_id)
|
await utils.send_json("complete_download_task", task_id)
|
||||||
|
|
||||||
async def update_progress():
|
async def update_progress():
|
||||||
nonlocal last_update_time
|
nonlocal last_update_time
|
||||||
@@ -329,6 +323,13 @@ async def download_model_file(
|
|||||||
# If no token is carried, it will be redirected to the login page.
|
# If no token is carried, it will be redirected to the login page.
|
||||||
content_type = response.headers.get("content-type")
|
content_type = response.headers.get("content-type")
|
||||||
if content_type and content_type.startswith("text/html"):
|
if content_type and content_type.startswith("text/html"):
|
||||||
|
# TODO More checks
|
||||||
|
# In addition to requiring login to download, there may be other restrictions.
|
||||||
|
# The currently one situation is early access??? issues#43
|
||||||
|
# Due to the lack of test data, let’s put it aside for now.
|
||||||
|
# If it cannot be downloaded, a redirect will definitely occur.
|
||||||
|
# Maybe consider getting the redirect url from response.history to make a judgment.
|
||||||
|
# Here we also need to consider how different websites are processed.
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
|
f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
|
||||||
)
|
)
|
||||||
@@ -340,7 +341,7 @@ async def download_model_file(
|
|||||||
task_content.sizeBytes = total_size
|
task_content.sizeBytes = total_size
|
||||||
task_status.totalSize = total_size
|
task_status.totalSize = total_size
|
||||||
set_task_content(task_id, task_content)
|
set_task_content(task_id, task_content)
|
||||||
await socket.send_json("updateDownloadTask", task_content)
|
await utils.send_json("update_download_task", task_content)
|
||||||
|
|
||||||
with open(download_tmp_file, "ab") as f:
|
with open(download_tmp_file, "ab") as f:
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
@@ -359,4 +360,4 @@ async def download_model_file(
|
|||||||
await download_complete()
|
await download_complete()
|
||||||
else:
|
else:
|
||||||
task_status.status = "pause"
|
task_status.status = "pause"
|
||||||
await socket.send_json("updateDownloadTask", task_status)
|
await utils.send_json("update_download_task", task_status)
|
||||||
|
|||||||
@@ -1,37 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from multidict import MultiDictProxy
|
from multidict import MultiDictProxy
|
||||||
from . import config
|
from . import config
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import socket
|
|
||||||
from . import download
|
from . import download
|
||||||
|
|
||||||
|
|
||||||
async def connect_websocket(request):
|
|
||||||
async def message_handler(event_type: str, detail: Any, sid: str):
|
|
||||||
try:
|
|
||||||
if event_type == "downloadTaskList":
|
|
||||||
await download.scan_model_download_task_list(sid=sid)
|
|
||||||
|
|
||||||
if event_type == "resumeDownloadTask":
|
|
||||||
await download.download_model(task_id=detail)
|
|
||||||
|
|
||||||
if event_type == "pauseDownloadTask":
|
|
||||||
await download.pause_model_download_task(task_id=detail)
|
|
||||||
|
|
||||||
if event_type == "deleteDownloadTask":
|
|
||||||
await download.delete_model_download_task(task_id=detail)
|
|
||||||
except Exception:
|
|
||||||
logging.error(traceback.format_exc())
|
|
||||||
|
|
||||||
ws = await socket.create_websocket_handler(request, handler=message_handler)
|
|
||||||
return ws
|
|
||||||
|
|
||||||
|
|
||||||
def scan_models():
|
def scan_models():
|
||||||
result = []
|
result = []
|
||||||
model_base_paths = config.model_base_paths
|
model_base_paths = config.model_base_paths
|
||||||
@@ -135,6 +111,22 @@ def remove_model(model_path: str):
|
|||||||
os.remove(utils.join_path(model_dirname, description))
|
os.remove(utils.join_path(model_dirname, description))
|
||||||
|
|
||||||
|
|
||||||
async def create_model_download_task(post):
|
async def create_model_download_task(post, request):
|
||||||
dict_post = dict(post)
|
dict_post = dict(post)
|
||||||
return await download.create_model_download_task(dict_post)
|
return await download.create_model_download_task(dict_post, request)
|
||||||
|
|
||||||
|
|
||||||
|
async def scan_model_download_task_list():
|
||||||
|
return await download.scan_model_download_task_list()
|
||||||
|
|
||||||
|
|
||||||
|
async def pause_model_download_task(task_id):
|
||||||
|
return await download.pause_model_download_task(task_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def resume_model_download_task(task_id, request):
|
||||||
|
return await download.download_model(task_id, request)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_model_download_task(task_id):
|
||||||
|
return await download.delete_model_download_task(task_id)
|
||||||
|
|||||||
63
py/socket.py
63
py/socket.py
@@ -1,63 +0,0 @@
|
|||||||
import aiohttp
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
import json
|
|
||||||
from aiohttp import web
|
|
||||||
from typing import Any, Callable, Awaitable
|
|
||||||
from . import utils
|
|
||||||
|
|
||||||
|
|
||||||
__sockets: dict[str, web.WebSocketResponse] = {}
|
|
||||||
|
|
||||||
|
|
||||||
async def create_websocket_handler(
|
|
||||||
request, handler: Callable[[str, Any, str], Awaitable[Any]]
|
|
||||||
):
|
|
||||||
ws = web.WebSocketResponse()
|
|
||||||
await ws.prepare(request)
|
|
||||||
sid = request.rel_url.query.get("clientId", "")
|
|
||||||
if sid:
|
|
||||||
# Reusing existing session, remove old
|
|
||||||
__sockets.pop(sid, None)
|
|
||||||
else:
|
|
||||||
sid = uuid.uuid4().hex
|
|
||||||
|
|
||||||
__sockets[sid] = ws
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for msg in ws:
|
|
||||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
|
||||||
logging.warning(
|
|
||||||
"ws connection closed with exception %s" % ws.exception()
|
|
||||||
)
|
|
||||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
||||||
data = json.loads(msg.data)
|
|
||||||
await handler(data.get("type"), data.get("detail"), sid)
|
|
||||||
finally:
|
|
||||||
__sockets.pop(sid, None)
|
|
||||||
return ws
|
|
||||||
|
|
||||||
|
|
||||||
async def send_json(event: str, data: Any, sid: str = None):
|
|
||||||
detail = utils.unpack_dataclass(data)
|
|
||||||
message = {"type": event, "data": detail}
|
|
||||||
|
|
||||||
if sid is None:
|
|
||||||
socket_list = list(__sockets.values())
|
|
||||||
for ws in socket_list:
|
|
||||||
await __send_socket_catch_exception(ws.send_json, message)
|
|
||||||
elif sid in __sockets:
|
|
||||||
await __send_socket_catch_exception(__sockets[sid].send_json, message)
|
|
||||||
|
|
||||||
|
|
||||||
async def __send_socket_catch_exception(function, message):
|
|
||||||
try:
|
|
||||||
await function(message)
|
|
||||||
except (
|
|
||||||
aiohttp.ClientError,
|
|
||||||
aiohttp.ClientPayloadError,
|
|
||||||
ConnectionResetError,
|
|
||||||
BrokenPipeError,
|
|
||||||
ConnectionError,
|
|
||||||
) as err:
|
|
||||||
logging.warning("send error: {}".format(err))
|
|
||||||
10
py/thread.py
10
py/thread.py
@@ -2,7 +2,6 @@ import asyncio
|
|||||||
import threading
|
import threading
|
||||||
import queue
|
import queue
|
||||||
import logging
|
import logging
|
||||||
from . import utils
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadThreadPool:
|
class DownloadThreadPool:
|
||||||
@@ -13,14 +12,7 @@ class DownloadThreadPool:
|
|||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
default_max_workers = 5
|
default_max_workers = 5
|
||||||
max_workers: int = utils.get_setting_value(
|
max_workers: int = default_max_workers
|
||||||
"download.max_task_count", default_max_workers
|
|
||||||
)
|
|
||||||
|
|
||||||
if max_workers <= 0:
|
|
||||||
max_workers = default_max_workers
|
|
||||||
utils.set_setting_value("download.max_task_count", max_workers)
|
|
||||||
|
|
||||||
self.max_worker = max_workers
|
self.max_worker = max_workers
|
||||||
|
|
||||||
def submit(self, task, task_id):
|
def submit(self, task, task_id):
|
||||||
|
|||||||
17
py/utils.py
17
py/utils.py
@@ -334,18 +334,16 @@ def resolve_setting_key(key: str) -> str:
|
|||||||
return setting_id
|
return setting_id
|
||||||
|
|
||||||
|
|
||||||
def set_setting_value(key: str, value: Any):
|
def set_setting_value(request: web.Request, key: str, value: Any):
|
||||||
setting_id = resolve_setting_key(key)
|
setting_id = resolve_setting_key(key)
|
||||||
fake_request = config.FakeRequest()
|
settings = config.serverInstance.user_manager.settings.get_settings(request)
|
||||||
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
|
|
||||||
settings[setting_id] = value
|
settings[setting_id] = value
|
||||||
config.serverInstance.user_manager.settings.save_settings(fake_request, settings)
|
config.serverInstance.user_manager.settings.save_settings(request, settings)
|
||||||
|
|
||||||
|
|
||||||
def get_setting_value(key: str, default: Any = None) -> Any:
|
def get_setting_value(request: web.Request, key: str, default: Any = None) -> Any:
|
||||||
setting_id = resolve_setting_key(key)
|
setting_id = resolve_setting_key(key)
|
||||||
fake_request = config.FakeRequest()
|
settings = config.serverInstance.user_manager.settings.get_settings(request)
|
||||||
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
|
|
||||||
return settings.get(setting_id, default)
|
return settings.get(setting_id, default)
|
||||||
|
|
||||||
|
|
||||||
@@ -361,3 +359,8 @@ def unpack_dataclass(data: Any):
|
|||||||
return asdict(data)
|
return asdict(data)
|
||||||
else:
|
else:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def send_json(event: str, data: Any, sid: str = None):
|
||||||
|
detail = unpack_dataclass(data)
|
||||||
|
await config.serverInstance.send_json(event, detail, sid)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-model-manager"
|
name = "comfyui-model-manager"
|
||||||
description = "Manage models: browsing, download and delete."
|
description = "Manage models: browsing, download and delete."
|
||||||
version = "2.0.2"
|
version = "2.0.3"
|
||||||
license = "LICENSE"
|
license = "LICENSE"
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
@@ -28,21 +28,23 @@
|
|||||||
|
|
||||||
<ResponseScroll class="-mx-5 h-full">
|
<ResponseScroll class="-mx-5 h-full">
|
||||||
<div class="px-5">
|
<div class="px-5">
|
||||||
<ModelContent
|
<KeepAlive>
|
||||||
v-if="currentModel"
|
<ModelContent
|
||||||
:key="currentModel.id"
|
v-if="currentModel"
|
||||||
:model="currentModel"
|
:key="currentModel.id"
|
||||||
:editable="true"
|
:model="currentModel"
|
||||||
@submit="createDownTask"
|
:editable="true"
|
||||||
>
|
@submit="createDownTask"
|
||||||
<template #action>
|
>
|
||||||
<Button
|
<template #action>
|
||||||
icon="pi pi-download"
|
<Button
|
||||||
:label="$t('download')"
|
icon="pi pi-download"
|
||||||
type="submit"
|
:label="$t('download')"
|
||||||
></Button>
|
type="submit"
|
||||||
</template>
|
></Button>
|
||||||
</ModelContent>
|
</template>
|
||||||
|
</ModelContent>
|
||||||
|
</KeepAlive>
|
||||||
|
|
||||||
<div v-show="data.length === 0">
|
<div v-show="data.length === 0">
|
||||||
<div class="flex flex-col items-center gap-4 py-8">
|
<div class="flex flex-col items-center gap-4 py-8">
|
||||||
|
|||||||
@@ -40,6 +40,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<ResponseScroll
|
<ResponseScroll
|
||||||
|
ref="responseScroll"
|
||||||
:items="list"
|
:items="list"
|
||||||
:itemSize="itemSize"
|
:itemSize="itemSize"
|
||||||
:row-key="(item) => item.map(genModelKey).join(',')"
|
:row-key="(item) => item.map(genModelKey).join(',')"
|
||||||
@@ -80,7 +81,7 @@ import ModelCard from 'components/ModelCard.vue'
|
|||||||
import ResponseInput from 'components/ResponseInput.vue'
|
import ResponseInput from 'components/ResponseInput.vue'
|
||||||
import ResponseSelect from 'components/ResponseSelect.vue'
|
import ResponseSelect from 'components/ResponseSelect.vue'
|
||||||
import ResponseScroll from 'components/ResponseScroll.vue'
|
import ResponseScroll from 'components/ResponseScroll.vue'
|
||||||
import { computed, ref } from 'vue'
|
import { computed, ref, watch } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { chunk } from 'lodash'
|
import { chunk } from 'lodash'
|
||||||
import { defineResizeCallback } from 'hooks/resize'
|
import { defineResizeCallback } from 'hooks/resize'
|
||||||
@@ -91,6 +92,8 @@ const { isMobile, cardWidth, gutter, aspect, modelFolders } = useConfig()
|
|||||||
const { data } = useModels()
|
const { data } = useModels()
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
|
const responseScroll = ref()
|
||||||
|
|
||||||
const searchContent = ref<string>()
|
const searchContent = ref<string>()
|
||||||
|
|
||||||
const currentType = ref('all')
|
const currentType = ref('all')
|
||||||
@@ -120,6 +123,10 @@ const sortOrderOptions = ref(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
watch([searchContent, currentType], () => {
|
||||||
|
responseScroll.value.init()
|
||||||
|
})
|
||||||
|
|
||||||
const itemSize = computed(() => {
|
const itemSize = computed(() => {
|
||||||
let itemWidth = cardWidth
|
let itemWidth = cardWidth
|
||||||
let itemGutter = gutter
|
let itemGutter = gutter
|
||||||
|
|||||||
@@ -298,7 +298,6 @@ const startDragThumb = (event: MouseEvent) => {
|
|||||||
watch(
|
watch(
|
||||||
() => props.items,
|
() => props.items,
|
||||||
() => {
|
() => {
|
||||||
init()
|
|
||||||
setSpacerSize()
|
setSpacerSize()
|
||||||
calculateScrollThumbSize()
|
calculateScrollThumbSize()
|
||||||
calculateLoadItems()
|
calculateLoadItems()
|
||||||
@@ -311,5 +310,6 @@ onUnmounted(() => {
|
|||||||
|
|
||||||
defineExpose({
|
defineExpose({
|
||||||
viewport,
|
viewport,
|
||||||
|
init,
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { useLoading } from 'hooks/loading'
|
import { useLoading } from 'hooks/loading'
|
||||||
import { MarkdownTool, useMarkdown } from 'hooks/markdown'
|
import { MarkdownTool, useMarkdown } from 'hooks/markdown'
|
||||||
import { socket } from 'hooks/socket'
|
import { request } from 'hooks/request'
|
||||||
import { defineStore } from 'hooks/store'
|
import { defineStore } from 'hooks/store'
|
||||||
import { useToast } from 'hooks/toast'
|
import { useToast } from 'hooks/toast'
|
||||||
|
import { api } from 'scripts/comfyAPI'
|
||||||
import { bytesToSize } from 'utils/common'
|
import { bytesToSize } from 'utils/common'
|
||||||
import { onBeforeMount, onMounted, ref, watch } from 'vue'
|
import { onBeforeMount, onMounted, ref, watch } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
@@ -13,10 +14,6 @@ export const useDownload = defineStore('download', (store) => {
|
|||||||
|
|
||||||
const taskList = ref<DownloadTask[]>([])
|
const taskList = ref<DownloadTask[]>([])
|
||||||
|
|
||||||
const refresh = () => {
|
|
||||||
socket.send('downloadTaskList', null)
|
|
||||||
}
|
|
||||||
|
|
||||||
const createTaskItem = (item: DownloadTaskOptions) => {
|
const createTaskItem = (item: DownloadTaskOptions) => {
|
||||||
const { downloadedSize, totalSize, bps, ...rest } = item
|
const { downloadedSize, totalSize, bps, ...rest } = item
|
||||||
|
|
||||||
@@ -26,10 +23,20 @@ export const useDownload = defineStore('download', (store) => {
|
|||||||
downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`,
|
downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`,
|
||||||
downloadSpeed: `${bytesToSize(bps)}/s`,
|
downloadSpeed: `${bytesToSize(bps)}/s`,
|
||||||
pauseTask() {
|
pauseTask() {
|
||||||
socket.send('pauseDownloadTask', item.taskId)
|
request(`/download/${item.taskId}`, {
|
||||||
|
method: 'PUT',
|
||||||
|
body: JSON.stringify({
|
||||||
|
status: 'pause',
|
||||||
|
}),
|
||||||
|
})
|
||||||
},
|
},
|
||||||
resumeTask: () => {
|
resumeTask: () => {
|
||||||
socket.send('resumeDownloadTask', item.taskId)
|
request(`/download/${item.taskId}`, {
|
||||||
|
method: 'PUT',
|
||||||
|
body: JSON.stringify({
|
||||||
|
status: 'resume',
|
||||||
|
}),
|
||||||
|
})
|
||||||
},
|
},
|
||||||
deleteTask: () => {
|
deleteTask: () => {
|
||||||
confirm.require({
|
confirm.require({
|
||||||
@@ -46,7 +53,9 @@ export const useDownload = defineStore('download', (store) => {
|
|||||||
severity: 'danger',
|
severity: 'danger',
|
||||||
},
|
},
|
||||||
accept: () => {
|
accept: () => {
|
||||||
socket.send('deleteDownloadTask', item.taskId)
|
request(`/download/${item.taskId}`, {
|
||||||
|
method: 'DELETE',
|
||||||
|
})
|
||||||
},
|
},
|
||||||
reject: () => {},
|
reject: () => {},
|
||||||
})
|
})
|
||||||
@@ -56,12 +65,28 @@ export const useDownload = defineStore('download', (store) => {
|
|||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const refresh = async () => {
|
||||||
|
return request('/download/task')
|
||||||
|
.then((resData: DownloadTaskOptions[]) => {
|
||||||
|
taskList.value = resData.map((item) => createTaskItem(item))
|
||||||
|
return taskList.value
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
toast.add({
|
||||||
|
severity: 'error',
|
||||||
|
summary: 'Error',
|
||||||
|
detail: err.message ?? 'Failed to refresh download task list',
|
||||||
|
life: 15000,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
onBeforeMount(() => {
|
onBeforeMount(() => {
|
||||||
socket.addEventListener('reconnected', () => {
|
api.addEventListener('reconnected', () => {
|
||||||
refresh()
|
refresh()
|
||||||
})
|
})
|
||||||
|
|
||||||
socket.addEventListener('downloadTaskList', (event) => {
|
api.addEventListener('fetch_download_task_list', (event) => {
|
||||||
const data = event.detail as DownloadTaskOptions[]
|
const data = event.detail as DownloadTaskOptions[]
|
||||||
|
|
||||||
taskList.value = data.map((item) => {
|
taskList.value = data.map((item) => {
|
||||||
@@ -69,12 +94,12 @@ export const useDownload = defineStore('download', (store) => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
socket.addEventListener('createDownloadTask', (event) => {
|
api.addEventListener('create_download_task', (event) => {
|
||||||
const item = event.detail as DownloadTaskOptions
|
const item = event.detail as DownloadTaskOptions
|
||||||
taskList.value.unshift(createTaskItem(item))
|
taskList.value.unshift(createTaskItem(item))
|
||||||
})
|
})
|
||||||
|
|
||||||
socket.addEventListener('updateDownloadTask', (event) => {
|
api.addEventListener('update_download_task', (event) => {
|
||||||
const item = event.detail as DownloadTaskOptions
|
const item = event.detail as DownloadTaskOptions
|
||||||
|
|
||||||
for (const task of taskList.value) {
|
for (const task of taskList.value) {
|
||||||
@@ -93,12 +118,12 @@ export const useDownload = defineStore('download', (store) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
socket.addEventListener('deleteDownloadTask', (event) => {
|
api.addEventListener('delete_download_task', (event) => {
|
||||||
const taskId = event.detail as string
|
const taskId = event.detail as string
|
||||||
taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
|
taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
|
||||||
})
|
})
|
||||||
|
|
||||||
socket.addEventListener('completeDownloadTask', (event) => {
|
api.addEventListener('complete_download_task', (event) => {
|
||||||
const taskId = event.detail as string
|
const taskId = event.detail as string
|
||||||
const task = taskList.value.find((item) => item.taskId === taskId)
|
const task = taskList.value.find((item) => item.taskId === taskId)
|
||||||
taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
|
taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
import { globalToast } from 'hooks/toast'
|
|
||||||
import { readonly } from 'vue'
|
|
||||||
|
|
||||||
class WebSocketEvent extends EventTarget {
|
|
||||||
private socket: WebSocket | null
|
|
||||||
|
|
||||||
constructor() {
|
|
||||||
super()
|
|
||||||
this.createSocket()
|
|
||||||
}
|
|
||||||
|
|
||||||
private createSocket(isReconnect?: boolean) {
|
|
||||||
const api_host = location.host
|
|
||||||
const api_base = location.pathname.split('/').slice(0, -1).join('/')
|
|
||||||
|
|
||||||
let opened = false
|
|
||||||
let existingSession = window.name
|
|
||||||
if (existingSession) {
|
|
||||||
existingSession = '?clientId=' + existingSession
|
|
||||||
}
|
|
||||||
|
|
||||||
this.socket = readonly(
|
|
||||||
new WebSocket(
|
|
||||||
`ws${window.location.protocol === 'https:' ? 's' : ''}://${api_host}${api_base}/model-manager/ws${existingSession}`,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
this.socket.addEventListener('open', () => {
|
|
||||||
opened = true
|
|
||||||
if (isReconnect) {
|
|
||||||
this.dispatchEvent(new CustomEvent('reconnected'))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
this.socket.addEventListener('error', () => {
|
|
||||||
if (this.socket) this.socket.close()
|
|
||||||
})
|
|
||||||
|
|
||||||
this.socket.addEventListener('close', (event) => {
|
|
||||||
setTimeout(() => {
|
|
||||||
this.socket = null
|
|
||||||
this.createSocket(true)
|
|
||||||
}, 300)
|
|
||||||
if (opened) {
|
|
||||||
this.dispatchEvent(new CustomEvent('status', { detail: null }))
|
|
||||||
this.dispatchEvent(new CustomEvent('reconnecting'))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
this.socket.addEventListener('message', (event) => {
|
|
||||||
try {
|
|
||||||
const msg = JSON.parse(event.data)
|
|
||||||
if (msg.type === 'error') {
|
|
||||||
globalToast.value?.add({
|
|
||||||
severity: 'error',
|
|
||||||
summary: 'Error',
|
|
||||||
detail: msg.data,
|
|
||||||
life: 15000,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }))
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error(error)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
addEventListener = (
|
|
||||||
type: string,
|
|
||||||
callback: CustomEventListener | null,
|
|
||||||
options?: AddEventListenerOptions | boolean,
|
|
||||||
) => {
|
|
||||||
super.addEventListener(type, callback, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
send(type: string, data: any) {
|
|
||||||
this.socket?.send(JSON.stringify({ type, detail: data }))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const socket = new WebSocketEvent()
|
|
||||||
Reference in New Issue
Block a user