feat: adapt to multi user

This commit is contained in:
hayden
2024-11-08 11:13:01 +08:00
parent ae518b541a
commit a1e5761dbc
9 changed files with 154 additions and 257 deletions

View File

@@ -1,37 +1,13 @@
import os
import logging
import traceback
import folder_paths
from typing import Any
from multidict import MultiDictProxy
from . import config
from . import utils
from . import socket
from . import download
async def connect_websocket(request):
async def message_handler(event_type: str, detail: Any, sid: str):
try:
if event_type == "downloadTaskList":
await download.scan_model_download_task_list(sid=sid)
if event_type == "resumeDownloadTask":
await download.download_model(task_id=detail)
if event_type == "pauseDownloadTask":
await download.pause_model_download_task(task_id=detail)
if event_type == "deleteDownloadTask":
await download.delete_model_download_task(task_id=detail)
except Exception:
logging.error(traceback.format_exc())
ws = await socket.create_websocket_handler(request, handler=message_handler)
return ws
def scan_models():
result = []
model_base_paths = config.model_base_paths
@@ -135,6 +111,22 @@ def remove_model(model_path: str):
os.remove(utils.join_path(model_dirname, description))
async def create_model_download_task(post):
async def create_model_download_task(post, request):
dict_post = dict(post)
return await download.create_model_download_task(dict_post)
return await download.create_model_download_task(dict_post, request)
async def scan_model_download_task_list():
return await download.scan_model_download_task_list()
async def pause_model_download_task(task_id):
return await download.pause_model_download_task(task_id)
async def resume_model_download_task(task_id, request):
return await download.download_model(task_id, request)
async def delete_model_download_task(task_id):
return await download.delete_model_download_task(task_id)