Merge pull request #47 from hayden-fr/feature-multi-user

feat: adapt to multi user
This commit is contained in:
Hayden
2024-11-11 11:11:09 +08:00
committed by GitHub
9 changed files with 154 additions and 257 deletions

View File

@@ -21,13 +21,61 @@ from .py import services
routes = config.routes
@routes.get("/model-manager/ws")
async def socket_handler(request):
@routes.get("/model-manager/download/task")
async def scan_download_tasks(request):
"""
Handle websocket connection.
Read download task list.
"""
ws = await services.connect_websocket(request)
return ws
try:
result = await services.scan_model_download_task_list()
return web.json_response({"success": True, "data": result})
except Exception as e:
error_msg = f"Read download task list failed: {e}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})
@routes.put("/model-manager/download/{task_id}")
async def resume_download_task(request):
"""
Toggle download task status.
"""
try:
task_id = request.match_info.get("task_id", None)
if task_id is None:
raise web.HTTPBadRequest(reason="Invalid task id")
json_data = await request.json()
status = json_data.get("status", None)
if status == "pause":
await services.pause_model_download_task(task_id)
elif status == "resume":
await services.resume_model_download_task(task_id, request)
else:
raise web.HTTPBadRequest(reason="Invalid status")
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Resume download task failed: {str(e)}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})
@routes.delete("/model-manager/download/{task_id}")
async def delete_model_download_task(request):
"""
Delete download task.
"""
task_id = request.match_info.get("task_id", None)
try:
await services.delete_model_download_task(task_id)
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Delete download task failed: {str(e)}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})
@routes.get("/model-manager/base-folders")
@@ -56,7 +104,7 @@ async def create_model(request):
"""
post = await request.post()
try:
task_id = await services.create_model_download_task(post)
task_id = await services.create_model_download_task(post, request)
return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e:
error_msg = f"Create model download task failed: {str(e)}"

View File

@@ -19,15 +19,3 @@ from server import PromptServer
serverInstance = PromptServer.instance
routes = serverInstance.routes
class FakeRequest:
def __init__(self):
self.headers = {}
class CustomException(BaseException):
def __init__(self, type: str, message: str = None) -> None:
self.type = type
self.message = message
super().__init__(message)

View File

@@ -9,7 +9,6 @@ from typing import Callable, Awaitable, Any, Literal, Union, Optional
from dataclasses import dataclass
from . import config
from . import utils
from . import socket
from . import thread
@@ -93,33 +92,28 @@ def delete_task_status(task_id: str):
download_model_task_status.pop(task_id, None)
async def scan_model_download_task_list(sid: str):
async def scan_model_download_task_list():
"""
Scan the download directory and send the task list to the client.
"""
try:
download_dir = utils.get_download_path()
task_files = utils.search_files(download_dir)
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = sorted(
task_files,
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
reverse=True,
)
task_list: list[dict] = []
for task_file in task_files:
task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id)
task_list.append(task_status)
download_dir = utils.get_download_path()
task_files = utils.search_files(download_dir)
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = sorted(
task_files,
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
reverse=True,
)
task_list: list[dict] = []
for task_file in task_files:
task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id)
task_list.append(task_status)
await socket.send_json("downloadTaskList", task_list, sid)
except Exception as e:
error_msg = f"Refresh task list failed: {e}"
await socket.send_json("error", error_msg, sid)
logging.error(error_msg)
return utils.unpack_dataclass(task_list)
async def create_model_download_task(post: dict):
async def create_model_download_task(post: dict, request):
"""
Creates a download task for the given post.
"""
@@ -152,12 +146,12 @@ async def create_model_download_task(post: dict):
totalSize=float(post.get("sizeBytes", 0)),
)
download_model_task_status[task_id] = task_status
await socket.send_json("createDownloadTask", task_status)
await utils.send_json("create_download_task", task_status)
except Exception as e:
await delete_model_download_task(task_id)
raise RuntimeError(str(e)) from e
await download_model(task_id)
await download_model(task_id, request)
return task_id
@@ -170,7 +164,7 @@ async def delete_model_download_task(task_id: str):
task_status = get_task_status(task_id)
is_running = task_status.status == "doing"
task_status.status = "waiting"
await socket.send_json("deleteDownloadTask", task_id)
await utils.send_json("delete_download_task", task_id)
# Pause the task
if is_running:
@@ -185,13 +179,13 @@ async def delete_model_download_task(task_id: str):
delete_task_status(task_id)
os.remove(utils.join_path(download_dir, task_file))
await socket.send_json("deleteDownloadTask", task_id)
await utils.send_json("delete_download_task", task_id)
async def download_model(task_id: str):
async def download_model(task_id: str, request):
async def download_task(task_id: str):
async def report_progress(task_status: TaskStatus):
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
try:
# When starting a task from the queue, the task may not exist
@@ -201,7 +195,7 @@ async def download_model(task_id: str):
# Update task status
task_status.status = "doing"
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
try:
@@ -210,12 +204,12 @@ async def download_model(task_id: str):
download_platform = task_status.platform
if download_platform == "civitai":
api_key = utils.get_setting_value("api_key.civitai")
api_key = utils.get_setting_value(request, "api_key.civitai")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
elif download_platform == "huggingface":
api_key = utils.get_setting_value("api_key.huggingface")
api_key = utils.get_setting_value(request, "api_key.huggingface")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
@@ -229,7 +223,7 @@ async def download_model(task_id: str):
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
task_status.error = None
logging.error(str(e))
@@ -238,11 +232,11 @@ async def download_model(task_id: str):
if status == "Waiting":
task_status = get_task_status(task_id)
task_status.status = "waiting"
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)
task_status.error = None
logging.error(traceback.format_exc())
@@ -275,7 +269,7 @@ async def download_model_file(
time.sleep(1)
task_file = utils.join_path(download_path, f"{task_id}.task")
os.remove(task_file)
await socket.send_json("completeDownloadTask", task_id)
await utils.send_json("complete_download_task", task_id)
async def update_progress():
nonlocal last_update_time
@@ -347,7 +341,7 @@ async def download_model_file(
task_content.sizeBytes = total_size
task_status.totalSize = total_size
set_task_content(task_id, task_content)
await socket.send_json("updateDownloadTask", task_content)
await utils.send_json("update_download_task", task_content)
with open(download_tmp_file, "ab") as f:
for chunk in response.iter_content(chunk_size=8192):
@@ -366,4 +360,4 @@ async def download_model_file(
await download_complete()
else:
task_status.status = "pause"
await socket.send_json("updateDownloadTask", task_status)
await utils.send_json("update_download_task", task_status)

View File

@@ -1,37 +1,13 @@
import os
import logging
import traceback
import folder_paths
from typing import Any
from multidict import MultiDictProxy
from . import config
from . import utils
from . import socket
from . import download
async def connect_websocket(request):
async def message_handler(event_type: str, detail: Any, sid: str):
try:
if event_type == "downloadTaskList":
await download.scan_model_download_task_list(sid=sid)
if event_type == "resumeDownloadTask":
await download.download_model(task_id=detail)
if event_type == "pauseDownloadTask":
await download.pause_model_download_task(task_id=detail)
if event_type == "deleteDownloadTask":
await download.delete_model_download_task(task_id=detail)
except Exception:
logging.error(traceback.format_exc())
ws = await socket.create_websocket_handler(request, handler=message_handler)
return ws
def scan_models():
result = []
model_base_paths = config.model_base_paths
@@ -135,6 +111,22 @@ def remove_model(model_path: str):
os.remove(utils.join_path(model_dirname, description))
async def create_model_download_task(post):
async def create_model_download_task(post, request):
dict_post = dict(post)
return await download.create_model_download_task(dict_post)
return await download.create_model_download_task(dict_post, request)
async def scan_model_download_task_list():
return await download.scan_model_download_task_list()
async def pause_model_download_task(task_id):
return await download.pause_model_download_task(task_id)
async def resume_model_download_task(task_id, request):
return await download.download_model(task_id, request)
async def delete_model_download_task(task_id):
return await download.delete_model_download_task(task_id)

View File

@@ -1,63 +0,0 @@
import aiohttp
import logging
import uuid
import json
from aiohttp import web
from typing import Any, Callable, Awaitable
from . import utils
__sockets: dict[str, web.WebSocketResponse] = {}
async def create_websocket_handler(
request, handler: Callable[[str, Any, str], Awaitable[Any]]
):
ws = web.WebSocketResponse()
await ws.prepare(request)
sid = request.rel_url.query.get("clientId", "")
if sid:
# Reusing existing session, remove old
__sockets.pop(sid, None)
else:
sid = uuid.uuid4().hex
__sockets[sid] = ws
try:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
logging.warning(
"ws connection closed with exception %s" % ws.exception()
)
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
await handler(data.get("type"), data.get("detail"), sid)
finally:
__sockets.pop(sid, None)
return ws
async def send_json(event: str, data: Any, sid: str = None):
detail = utils.unpack_dataclass(data)
message = {"type": event, "data": detail}
if sid is None:
socket_list = list(__sockets.values())
for ws in socket_list:
await __send_socket_catch_exception(ws.send_json, message)
elif sid in __sockets:
await __send_socket_catch_exception(__sockets[sid].send_json, message)
async def __send_socket_catch_exception(function, message):
try:
await function(message)
except (
aiohttp.ClientError,
aiohttp.ClientPayloadError,
ConnectionResetError,
BrokenPipeError,
ConnectionError,
) as err:
logging.warning("send error: {}".format(err))

View File

@@ -2,7 +2,6 @@ import asyncio
import threading
import queue
import logging
from . import utils
class DownloadThreadPool:
@@ -13,14 +12,7 @@ class DownloadThreadPool:
self._lock = threading.Lock()
default_max_workers = 5
max_workers: int = utils.get_setting_value(
"download.max_task_count", default_max_workers
)
if max_workers <= 0:
max_workers = default_max_workers
utils.set_setting_value("download.max_task_count", max_workers)
max_workers: int = default_max_workers
self.max_worker = max_workers
def submit(self, task, task_id):

View File

@@ -334,18 +334,16 @@ def resolve_setting_key(key: str) -> str:
return setting_id
def set_setting_value(key: str, value: Any):
def set_setting_value(request: web.Request, key: str, value: Any):
setting_id = resolve_setting_key(key)
fake_request = config.FakeRequest()
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
settings = config.serverInstance.user_manager.settings.get_settings(request)
settings[setting_id] = value
config.serverInstance.user_manager.settings.save_settings(fake_request, settings)
config.serverInstance.user_manager.settings.save_settings(request, settings)
def get_setting_value(key: str, default: Any = None) -> Any:
def get_setting_value(request: web.Request, key: str, default: Any = None) -> Any:
setting_id = resolve_setting_key(key)
fake_request = config.FakeRequest()
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
settings = config.serverInstance.user_manager.settings.get_settings(request)
return settings.get(setting_id, default)
@@ -361,3 +359,8 @@ def unpack_dataclass(data: Any):
return asdict(data)
else:
return data
async def send_json(event: str, data: Any, sid: str = None):
detail = unpack_dataclass(data)
await config.serverInstance.send_json(event, detail, sid)

View File

@@ -1,8 +1,9 @@
import { useLoading } from 'hooks/loading'
import { MarkdownTool, useMarkdown } from 'hooks/markdown'
import { socket } from 'hooks/socket'
import { request } from 'hooks/request'
import { defineStore } from 'hooks/store'
import { useToast } from 'hooks/toast'
import { api } from 'scripts/comfyAPI'
import { bytesToSize } from 'utils/common'
import { onBeforeMount, onMounted, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
@@ -13,10 +14,6 @@ export const useDownload = defineStore('download', (store) => {
const taskList = ref<DownloadTask[]>([])
const refresh = () => {
socket.send('downloadTaskList', null)
}
const createTaskItem = (item: DownloadTaskOptions) => {
const { downloadedSize, totalSize, bps, ...rest } = item
@@ -26,10 +23,20 @@ export const useDownload = defineStore('download', (store) => {
downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`,
downloadSpeed: `${bytesToSize(bps)}/s`,
pauseTask() {
socket.send('pauseDownloadTask', item.taskId)
request(`/download/${item.taskId}`, {
method: 'PUT',
body: JSON.stringify({
status: 'pause',
}),
})
},
resumeTask: () => {
socket.send('resumeDownloadTask', item.taskId)
request(`/download/${item.taskId}`, {
method: 'PUT',
body: JSON.stringify({
status: 'resume',
}),
})
},
deleteTask: () => {
confirm.require({
@@ -46,7 +53,9 @@ export const useDownload = defineStore('download', (store) => {
severity: 'danger',
},
accept: () => {
socket.send('deleteDownloadTask', item.taskId)
request(`/download/${item.taskId}`, {
method: 'DELETE',
})
},
reject: () => {},
})
@@ -56,12 +65,28 @@ export const useDownload = defineStore('download', (store) => {
return task
}
const refresh = async () => {
return request('/download/task')
.then((resData: DownloadTaskOptions[]) => {
taskList.value = resData.map((item) => createTaskItem(item))
return taskList.value
})
.catch((err) => {
toast.add({
severity: 'error',
summary: 'Error',
detail: err.message ?? 'Failed to refresh download task list',
life: 15000,
})
})
}
onBeforeMount(() => {
socket.addEventListener('reconnected', () => {
api.addEventListener('reconnected', () => {
refresh()
})
socket.addEventListener('downloadTaskList', (event) => {
api.addEventListener('fetch_download_task_list', (event) => {
const data = event.detail as DownloadTaskOptions[]
taskList.value = data.map((item) => {
@@ -69,12 +94,12 @@ export const useDownload = defineStore('download', (store) => {
})
})
socket.addEventListener('createDownloadTask', (event) => {
api.addEventListener('create_download_task', (event) => {
const item = event.detail as DownloadTaskOptions
taskList.value.unshift(createTaskItem(item))
})
socket.addEventListener('updateDownloadTask', (event) => {
api.addEventListener('update_download_task', (event) => {
const item = event.detail as DownloadTaskOptions
for (const task of taskList.value) {
@@ -93,12 +118,12 @@ export const useDownload = defineStore('download', (store) => {
}
})
socket.addEventListener('deleteDownloadTask', (event) => {
api.addEventListener('delete_download_task', (event) => {
const taskId = event.detail as string
taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
})
socket.addEventListener('completeDownloadTask', (event) => {
api.addEventListener('complete_download_task', (event) => {
const taskId = event.detail as string
const task = taskList.value.find((item) => item.taskId === taskId)
taskList.value = taskList.value.filter((item) => item.taskId !== taskId)

View File

@@ -1,82 +0,0 @@
import { globalToast } from 'hooks/toast'
import { readonly } from 'vue'
class WebSocketEvent extends EventTarget {
private socket: WebSocket | null
constructor() {
super()
this.createSocket()
}
private createSocket(isReconnect?: boolean) {
const api_host = location.host
const api_base = location.pathname.split('/').slice(0, -1).join('/')
let opened = false
let existingSession = window.name
if (existingSession) {
existingSession = '?clientId=' + existingSession
}
this.socket = readonly(
new WebSocket(
`ws${window.location.protocol === 'https:' ? 's' : ''}://${api_host}${api_base}/model-manager/ws${existingSession}`,
),
)
this.socket.addEventListener('open', () => {
opened = true
if (isReconnect) {
this.dispatchEvent(new CustomEvent('reconnected'))
}
})
this.socket.addEventListener('error', () => {
if (this.socket) this.socket.close()
})
this.socket.addEventListener('close', (event) => {
setTimeout(() => {
this.socket = null
this.createSocket(true)
}, 300)
if (opened) {
this.dispatchEvent(new CustomEvent('status', { detail: null }))
this.dispatchEvent(new CustomEvent('reconnecting'))
}
})
this.socket.addEventListener('message', (event) => {
try {
const msg = JSON.parse(event.data)
if (msg.type === 'error') {
globalToast.value?.add({
severity: 'error',
summary: 'Error',
detail: msg.data,
life: 15000,
})
} else {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }))
}
} catch (error) {
console.error(error)
}
})
}
addEventListener = (
type: string,
callback: CustomEventListener | null,
options?: AddEventListenerOptions | boolean,
) => {
super.addEventListener(type, callback, options)
}
send(type: string, data: any) {
this.socket?.send(JSON.stringify({ type, detail: data }))
}
}
export const socket = new WebSocketEvent()