feat: adapt to multi user
This commit is contained in:
12
py/config.py
12
py/config.py
@@ -19,15 +19,3 @@ from server import PromptServer
|
||||
|
||||
serverInstance = PromptServer.instance
|
||||
routes = serverInstance.routes
|
||||
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self):
|
||||
self.headers = {}
|
||||
|
||||
|
||||
class CustomException(BaseException):
|
||||
def __init__(self, type: str, message: str = None) -> None:
|
||||
self.type = type
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Callable, Awaitable, Any, Literal, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
from . import config
|
||||
from . import utils
|
||||
from . import socket
|
||||
from . import thread
|
||||
|
||||
|
||||
@@ -93,33 +92,28 @@ def delete_task_status(task_id: str):
|
||||
download_model_task_status.pop(task_id, None)
|
||||
|
||||
|
||||
async def scan_model_download_task_list(sid: str):
|
||||
async def scan_model_download_task_list():
|
||||
"""
|
||||
Scan the download directory and send the task list to the client.
|
||||
"""
|
||||
try:
|
||||
download_dir = utils.get_download_path()
|
||||
task_files = utils.search_files(download_dir)
|
||||
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
||||
task_files = sorted(
|
||||
task_files,
|
||||
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
||||
reverse=True,
|
||||
)
|
||||
task_list: list[dict] = []
|
||||
for task_file in task_files:
|
||||
task_id = task_file.replace(".task", "")
|
||||
task_status = get_task_status(task_id)
|
||||
task_list.append(task_status)
|
||||
download_dir = utils.get_download_path()
|
||||
task_files = utils.search_files(download_dir)
|
||||
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
||||
task_files = sorted(
|
||||
task_files,
|
||||
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
||||
reverse=True,
|
||||
)
|
||||
task_list: list[dict] = []
|
||||
for task_file in task_files:
|
||||
task_id = task_file.replace(".task", "")
|
||||
task_status = get_task_status(task_id)
|
||||
task_list.append(task_status)
|
||||
|
||||
await socket.send_json("downloadTaskList", task_list, sid)
|
||||
except Exception as e:
|
||||
error_msg = f"Refresh task list failed: {e}"
|
||||
await socket.send_json("error", error_msg, sid)
|
||||
logging.error(error_msg)
|
||||
return utils.unpack_dataclass(task_list)
|
||||
|
||||
|
||||
async def create_model_download_task(post: dict):
|
||||
async def create_model_download_task(post: dict, request):
|
||||
"""
|
||||
Creates a download task for the given post.
|
||||
"""
|
||||
@@ -152,12 +146,12 @@ async def create_model_download_task(post: dict):
|
||||
totalSize=float(post.get("sizeBytes", 0)),
|
||||
)
|
||||
download_model_task_status[task_id] = task_status
|
||||
await socket.send_json("createDownloadTask", task_status)
|
||||
await utils.send_json("create_download_task", task_status)
|
||||
except Exception as e:
|
||||
await delete_model_download_task(task_id)
|
||||
raise RuntimeError(str(e)) from e
|
||||
|
||||
await download_model(task_id)
|
||||
await download_model(task_id, request)
|
||||
return task_id
|
||||
|
||||
|
||||
@@ -170,7 +164,7 @@ async def delete_model_download_task(task_id: str):
|
||||
task_status = get_task_status(task_id)
|
||||
is_running = task_status.status == "doing"
|
||||
task_status.status = "waiting"
|
||||
await socket.send_json("deleteDownloadTask", task_id)
|
||||
await utils.send_json("delete_download_task", task_id)
|
||||
|
||||
# Pause the task
|
||||
if is_running:
|
||||
@@ -185,13 +179,13 @@ async def delete_model_download_task(task_id: str):
|
||||
delete_task_status(task_id)
|
||||
os.remove(utils.join_path(download_dir, task_file))
|
||||
|
||||
await socket.send_json("deleteDownloadTask", task_id)
|
||||
await utils.send_json("delete_download_task", task_id)
|
||||
|
||||
|
||||
async def download_model(task_id: str):
|
||||
async def download_model(task_id: str, request):
|
||||
async def download_task(task_id: str):
|
||||
async def report_progress(task_status: TaskStatus):
|
||||
await socket.send_json("updateDownloadTask", task_status)
|
||||
await utils.send_json("update_download_task", task_status)
|
||||
|
||||
try:
|
||||
# When starting a task from the queue, the task may not exist
|
||||
@@ -201,7 +195,7 @@ async def download_model(task_id: str):
|
||||
|
||||
# Update task status
|
||||
task_status.status = "doing"
|
||||
await socket.send_json("updateDownloadTask", task_status)
|
||||
await utils.send_json("update_download_task", task_status)
|
||||
|
||||
try:
|
||||
|
||||
@@ -210,12 +204,12 @@ async def download_model(task_id: str):
|
||||
|
||||
download_platform = task_status.platform
|
||||
if download_platform == "civitai":
|
||||
api_key = utils.get_setting_value("api_key.civitai")
|
||||
api_key = utils.get_setting_value(request, "api_key.civitai")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
elif download_platform == "huggingface":
|
||||
api_key = utils.get_setting_value("api_key.huggingface")
|
||||
api_key = utils.get_setting_value(request, "api_key.huggingface")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
@@ -229,7 +223,7 @@ async def download_model(task_id: str):
|
||||
except Exception as e:
|
||||
task_status.status = "pause"
|
||||
task_status.error = str(e)
|
||||
await socket.send_json("updateDownloadTask", task_status)
|
||||
await utils.send_json("update_download_task", task_status)
|
||||
task_status.error = None
|
||||
logging.error(str(e))
|
||||
|
||||
@@ -238,11 +232,11 @@ async def download_model(task_id: str):
|
||||
if status == "Waiting":
|
||||
task_status = get_task_status(task_id)
|
||||
task_status.status = "waiting"
|
||||
await socket.send_json("updateDownloadTask", task_status)
|
||||
await utils.send_json("update_download_task", task_status)
|
||||
except Exception as e:
|
||||
task_status.status = "pause"
|
||||
task_status.error = str(e)
|
||||
await socket.send_json("updateDownloadTask", task_status)
|
||||
await utils.send_json("update_download_task", task_status)
|
||||
task_status.error = None
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
@@ -275,7 +269,7 @@ async def download_model_file(
|
||||
time.sleep(1)
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
os.remove(task_file)
|
||||
await socket.send_json("completeDownloadTask", task_id)
|
||||
await utils.send_json("complete_download_task", task_id)
|
||||
|
||||
async def update_progress():
|
||||
nonlocal last_update_time
|
||||
@@ -347,7 +341,7 @@ async def download_model_file(
|
||||
task_content.sizeBytes = total_size
|
||||
task_status.totalSize = total_size
|
||||
set_task_content(task_id, task_content)
|
||||
await socket.send_json("updateDownloadTask", task_content)
|
||||
await utils.send_json("update_download_task", task_content)
|
||||
|
||||
with open(download_tmp_file, "ab") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
@@ -366,4 +360,4 @@ async def download_model_file(
|
||||
await download_complete()
|
||||
else:
|
||||
task_status.status = "pause"
|
||||
await socket.send_json("updateDownloadTask", task_status)
|
||||
await utils.send_json("update_download_task", task_status)
|
||||
|
||||
@@ -1,37 +1,13 @@
|
||||
import os
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
import folder_paths
|
||||
|
||||
from typing import Any
|
||||
from multidict import MultiDictProxy
|
||||
from . import config
|
||||
from . import utils
|
||||
from . import socket
|
||||
from . import download
|
||||
|
||||
|
||||
async def connect_websocket(request):
|
||||
async def message_handler(event_type: str, detail: Any, sid: str):
|
||||
try:
|
||||
if event_type == "downloadTaskList":
|
||||
await download.scan_model_download_task_list(sid=sid)
|
||||
|
||||
if event_type == "resumeDownloadTask":
|
||||
await download.download_model(task_id=detail)
|
||||
|
||||
if event_type == "pauseDownloadTask":
|
||||
await download.pause_model_download_task(task_id=detail)
|
||||
|
||||
if event_type == "deleteDownloadTask":
|
||||
await download.delete_model_download_task(task_id=detail)
|
||||
except Exception:
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
ws = await socket.create_websocket_handler(request, handler=message_handler)
|
||||
return ws
|
||||
|
||||
|
||||
def scan_models():
|
||||
result = []
|
||||
model_base_paths = config.model_base_paths
|
||||
@@ -135,6 +111,22 @@ def remove_model(model_path: str):
|
||||
os.remove(utils.join_path(model_dirname, description))
|
||||
|
||||
|
||||
async def create_model_download_task(post):
|
||||
async def create_model_download_task(post, request):
|
||||
dict_post = dict(post)
|
||||
return await download.create_model_download_task(dict_post)
|
||||
return await download.create_model_download_task(dict_post, request)
|
||||
|
||||
|
||||
async def scan_model_download_task_list():
|
||||
return await download.scan_model_download_task_list()
|
||||
|
||||
|
||||
async def pause_model_download_task(task_id):
|
||||
return await download.pause_model_download_task(task_id)
|
||||
|
||||
|
||||
async def resume_model_download_task(task_id, request):
|
||||
return await download.download_model(task_id, request)
|
||||
|
||||
|
||||
async def delete_model_download_task(task_id):
|
||||
return await download.delete_model_download_task(task_id)
|
||||
|
||||
63
py/socket.py
63
py/socket.py
@@ -1,63 +0,0 @@
|
||||
import aiohttp
|
||||
import logging
|
||||
import uuid
|
||||
import json
|
||||
from aiohttp import web
|
||||
from typing import Any, Callable, Awaitable
|
||||
from . import utils
|
||||
|
||||
|
||||
__sockets: dict[str, web.WebSocketResponse] = {}
|
||||
|
||||
|
||||
async def create_websocket_handler(
|
||||
request, handler: Callable[[str, Any, str], Awaitable[Any]]
|
||||
):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
sid = request.rel_url.query.get("clientId", "")
|
||||
if sid:
|
||||
# Reusing existing session, remove old
|
||||
__sockets.pop(sid, None)
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
__sockets[sid] = ws
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logging.warning(
|
||||
"ws connection closed with exception %s" % ws.exception()
|
||||
)
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
await handler(data.get("type"), data.get("detail"), sid)
|
||||
finally:
|
||||
__sockets.pop(sid, None)
|
||||
return ws
|
||||
|
||||
|
||||
async def send_json(event: str, data: Any, sid: str = None):
|
||||
detail = utils.unpack_dataclass(data)
|
||||
message = {"type": event, "data": detail}
|
||||
|
||||
if sid is None:
|
||||
socket_list = list(__sockets.values())
|
||||
for ws in socket_list:
|
||||
await __send_socket_catch_exception(ws.send_json, message)
|
||||
elif sid in __sockets:
|
||||
await __send_socket_catch_exception(__sockets[sid].send_json, message)
|
||||
|
||||
|
||||
async def __send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
aiohttp.ClientPayloadError,
|
||||
ConnectionResetError,
|
||||
BrokenPipeError,
|
||||
ConnectionError,
|
||||
) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
10
py/thread.py
10
py/thread.py
@@ -2,7 +2,6 @@ import asyncio
|
||||
import threading
|
||||
import queue
|
||||
import logging
|
||||
from . import utils
|
||||
|
||||
|
||||
class DownloadThreadPool:
|
||||
@@ -13,14 +12,7 @@ class DownloadThreadPool:
|
||||
self._lock = threading.Lock()
|
||||
|
||||
default_max_workers = 5
|
||||
max_workers: int = utils.get_setting_value(
|
||||
"download.max_task_count", default_max_workers
|
||||
)
|
||||
|
||||
if max_workers <= 0:
|
||||
max_workers = default_max_workers
|
||||
utils.set_setting_value("download.max_task_count", max_workers)
|
||||
|
||||
max_workers: int = default_max_workers
|
||||
self.max_worker = max_workers
|
||||
|
||||
def submit(self, task, task_id):
|
||||
|
||||
17
py/utils.py
17
py/utils.py
@@ -334,18 +334,16 @@ def resolve_setting_key(key: str) -> str:
|
||||
return setting_id
|
||||
|
||||
|
||||
def set_setting_value(key: str, value: Any):
|
||||
def set_setting_value(request: web.Request, key: str, value: Any):
|
||||
setting_id = resolve_setting_key(key)
|
||||
fake_request = config.FakeRequest()
|
||||
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
|
||||
settings = config.serverInstance.user_manager.settings.get_settings(request)
|
||||
settings[setting_id] = value
|
||||
config.serverInstance.user_manager.settings.save_settings(fake_request, settings)
|
||||
config.serverInstance.user_manager.settings.save_settings(request, settings)
|
||||
|
||||
|
||||
def get_setting_value(key: str, default: Any = None) -> Any:
|
||||
def get_setting_value(request: web.Request, key: str, default: Any = None) -> Any:
|
||||
setting_id = resolve_setting_key(key)
|
||||
fake_request = config.FakeRequest()
|
||||
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
|
||||
settings = config.serverInstance.user_manager.settings.get_settings(request)
|
||||
return settings.get(setting_id, default)
|
||||
|
||||
|
||||
@@ -361,3 +359,8 @@ def unpack_dataclass(data: Any):
|
||||
return asdict(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
async def send_json(event: str, data: Any, sid: str = None):
|
||||
detail = unpack_dataclass(data)
|
||||
await config.serverInstance.send_json(event, detail, sid)
|
||||
|
||||
Reference in New Issue
Block a user