23 Commits

Author SHA1 Message Date
Hayden
e36af38375 prepare release 2.0.3 2024-11-11 11:13:24 +08:00
Hayden
d4922f59d3 Merge pull request #50 from hayden-fr/feature-optimize-ui
Feature optimize UI
2024-11-11 11:11:30 +08:00
Hayden
f2e17744ae Merge pull request #47 from hayden-fr/feature-multi-user
feat: adapt to multi user
2024-11-11 11:11:09 +08:00
hayden
3b25d3e347 pref: optimize the timing of scrollbar reset 2024-11-08 12:42:00 +08:00
hayden
3a0676b29f pref(download): keep model content status 2024-11-08 11:49:18 +08:00
hayden
a1e5761dbc feat: adapt to multi user 2024-11-08 11:13:01 +08:00
hayden
ae518b541a chore(download): add todo notes 2024-11-07 09:42:37 +08:00
hayden
f22fbd46ad prepare release 2.0.2 2024-11-07 08:50:55 +08:00
Hayden
8c3a001657 Merge pull request #46 from hayden-fr/develop
Update: resolving windows issue and enhancing model display
2024-11-07 08:47:49 +08:00
hayden
d052d9dceb fix: optimize markdown style 2024-11-06 17:15:21 +08:00
hayden
652721ac9a fix: hide action button until mouseover 2024-11-06 16:03:28 +08:00
hayden
cfd2bdea4a fix(ResponseInput): unable input any text 2024-11-06 15:55:49 +08:00
hayden
b8cd3c28a5 fix: bug in verification update description error 2024-11-06 15:41:22 +08:00
hayden
153dbc0788 fix: issue saving differences across platforms 2024-11-06 15:39:06 +08:00
hayden
288f026d47 feat: add display of directory information 2024-11-06 13:51:38 +08:00
hayden
0a8c532506 feat: optimize model editing
- close dialog after delete or rename
- keep editing if model update fails
- show more error message
2024-11-05 17:02:10 +08:00
hayden
8bfe601588 chore: optimize development address 2024-11-05 16:46:30 +08:00
hayden
7a183464ae fix: cross-platform paths 2024-11-05 16:44:41 +08:00
hayden
f9b0afcbf5 chore: prepare publish 2.0.1 2024-11-05 09:34:14 +08:00
Hayden
1f4c55ab89 Merge pull request #42 from hayden-fr/hotfix
fix: Cross-device movement
2024-11-04 12:05:28 +08:00
hayden
da1ec3a52c fix: Cross-device movement 2024-11-04 12:01:27 +08:00
Hayden
79b106d986 Merge pull request #41 from sansmoraxz/patch-1
Fix image path resolution for windows
2024-11-04 11:13:20 +08:00
Souyama
4c1af63d0d Update utils.py
Fix image resolution for windows
2024-11-03 18:16:45 +05:30
22 changed files with 378 additions and 362 deletions

View File

@@ -5,7 +5,7 @@ from .py import utils
# Init config settings # Init config settings
config.extension_uri = os.path.dirname(__file__) config.extension_uri = utils.normalize_path(os.path.dirname(__file__))
utils.resolve_model_base_paths() utils.resolve_model_base_paths()
version = utils.get_current_version() version = utils.get_current_version()
@@ -21,13 +21,61 @@ from .py import services
routes = config.routes routes = config.routes
@routes.get("/model-manager/ws") @routes.get("/model-manager/download/task")
async def socket_handler(request): async def scan_download_tasks(request):
""" """
Handle websocket connection. Read download task list.
""" """
ws = await services.connect_websocket(request) try:
return ws result = await services.scan_model_download_task_list()
return web.json_response({"success": True, "data": result})
except Exception as e:
error_msg = f"Read download task list failed: {e}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})
@routes.put("/model-manager/download/{task_id}")
async def resume_download_task(request):
"""
Toggle download task status.
"""
try:
task_id = request.match_info.get("task_id", None)
if task_id is None:
raise web.HTTPBadRequest(reason="Invalid task id")
json_data = await request.json()
status = json_data.get("status", None)
if status == "pause":
await services.pause_model_download_task(task_id)
elif status == "resume":
await services.resume_model_download_task(task_id, request)
else:
raise web.HTTPBadRequest(reason="Invalid status")
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Resume download task failed: {str(e)}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})
@routes.delete("/model-manager/download/{task_id}")
async def delete_model_download_task(request):
"""
Delete download task.
"""
task_id = request.match_info.get("task_id", None)
try:
await services.delete_model_download_task(task_id)
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Delete download task failed: {str(e)}"
logging.error(error_msg)
logging.debug(traceback.format_exc())
return web.json_response({"success": False, "error": error_msg})
@routes.get("/model-manager/base-folders") @routes.get("/model-manager/base-folders")
@@ -56,7 +104,7 @@ async def create_model(request):
""" """
post = await request.post() post = await request.post()
try: try:
task_id = await services.create_model_download_task(post) task_id = await services.create_model_download_task(post, request)
return web.json_response({"success": True, "data": {"taskId": task_id}}) return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e: except Exception as e:
error_msg = f"Create model download task failed: {str(e)}" error_msg = f"Create model download task failed: {str(e)}"
@@ -173,12 +221,12 @@ async def read_model_preview(request):
try: try:
folders = folder_paths.get_folder_paths(model_type) folders = folder_paths.get_folder_paths(model_type)
base_path = folders[index] base_path = folders[index]
abs_path = os.path.join(base_path, filename) abs_path = utils.join_path(base_path, filename)
except: except:
abs_path = extension_uri abs_path = extension_uri
if not os.path.isfile(abs_path): if not os.path.isfile(abs_path):
abs_path = os.path.join(extension_uri, "assets", "no-preview.png") abs_path = utils.join_path(extension_uri, "assets", "no-preview.png")
return web.FileResponse(abs_path) return web.FileResponse(abs_path)
@@ -188,10 +236,10 @@ async def read_download_preview(request):
extension_uri = config.extension_uri extension_uri = config.extension_uri
download_path = utils.get_download_path() download_path = utils.get_download_path()
preview_path = os.path.join(download_path, filename) preview_path = utils.join_path(download_path, filename)
if not os.path.isfile(preview_path): if not os.path.isfile(preview_path):
preview_path = os.path.join(extension_uri, "assets", "no-preview.png") preview_path = utils.join_path(extension_uri, "assets", "no-preview.png")
return web.FileResponse(preview_path) return web.FileResponse(preview_path)

View File

@@ -19,15 +19,3 @@ from server import PromptServer
serverInstance = PromptServer.instance serverInstance = PromptServer.instance
routes = serverInstance.routes routes = serverInstance.routes
class FakeRequest:
def __init__(self):
self.headers = {}
class CustomException(BaseException):
def __init__(self, type: str, message: str = None) -> None:
self.type = type
self.message = message
super().__init__(message)

View File

@@ -9,7 +9,6 @@ from typing import Callable, Awaitable, Any, Literal, Union, Optional
from dataclasses import dataclass from dataclasses import dataclass
from . import config from . import config
from . import utils from . import utils
from . import socket
from . import thread from . import thread
@@ -46,13 +45,13 @@ download_thread_pool = thread.DownloadThreadPool()
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]): def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_file_path = os.path.join(download_path, f"{task_id}.task") task_file_path = utils.join_path(download_path, f"{task_id}.task")
utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content)) utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
def get_task_content(task_id: str): def get_task_content(task_id: str):
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_file = os.path.join(download_path, f"{task_id}.task") task_file = utils.join_path(download_path, f"{task_id}.task")
if not os.path.isfile(task_file): if not os.path.isfile(task_file):
raise RuntimeError(f"Task {task_id} not found") raise RuntimeError(f"Task {task_id} not found")
task_content = utils.load_dict_pickle_file(task_file) task_content = utils.load_dict_pickle_file(task_file)
@@ -67,7 +66,7 @@ def get_task_status(task_id: str):
if task_status is None: if task_status is None:
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_content = get_task_content(task_id) task_content = get_task_content(task_id)
download_file = os.path.join(download_path, f"{task_id}.download") download_file = utils.join_path(download_path, f"{task_id}.download")
download_size = 0 download_size = 0
if os.path.exists(download_file): if os.path.exists(download_file):
download_size = os.path.getsize(download_file) download_size = os.path.getsize(download_file)
@@ -93,33 +92,28 @@ def delete_task_status(task_id: str):
download_model_task_status.pop(task_id, None) download_model_task_status.pop(task_id, None)
async def scan_model_download_task_list(sid: str): async def scan_model_download_task_list():
""" """
Scan the download directory and send the task list to the client. Scan the download directory and send the task list to the client.
""" """
try: download_dir = utils.get_download_path()
download_dir = utils.get_download_path() task_files = utils.search_files(download_dir)
task_files = utils.search_files(download_dir) task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = folder_paths.filter_files_extensions(task_files, [".task"]) task_files = sorted(
task_files = sorted( task_files,
task_files, key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime, reverse=True,
reverse=True, )
) task_list: list[dict] = []
task_list: list[dict] = [] for task_file in task_files:
for task_file in task_files: task_id = task_file.replace(".task", "")
task_id = task_file.replace(".task", "") task_status = get_task_status(task_id)
task_status = get_task_status(task_id) task_list.append(task_status)
task_list.append(task_status)
await socket.send_json("downloadTaskList", task_list, sid) return utils.unpack_dataclass(task_list)
except Exception as e:
error_msg = f"Refresh task list failed: {e}"
await socket.send_json("error", error_msg, sid)
logging.error(error_msg)
async def create_model_download_task(post: dict): async def create_model_download_task(post: dict, request):
""" """
Creates a download task for the given post. Creates a download task for the given post.
""" """
@@ -135,7 +129,7 @@ async def create_model_download_task(post: dict):
download_path = utils.get_download_path() download_path = utils.get_download_path()
task_id = uuid.uuid4().hex task_id = uuid.uuid4().hex
task_path = os.path.join(download_path, f"{task_id}.task") task_path = utils.join_path(download_path, f"{task_id}.task")
if os.path.exists(task_path): if os.path.exists(task_path):
raise RuntimeError(f"Task {task_id} already exists") raise RuntimeError(f"Task {task_id} already exists")
@@ -152,12 +146,12 @@ async def create_model_download_task(post: dict):
totalSize=float(post.get("sizeBytes", 0)), totalSize=float(post.get("sizeBytes", 0)),
) )
download_model_task_status[task_id] = task_status download_model_task_status[task_id] = task_status
await socket.send_json("createDownloadTask", task_status) await utils.send_json("create_download_task", task_status)
except Exception as e: except Exception as e:
await delete_model_download_task(task_id) await delete_model_download_task(task_id)
raise RuntimeError(str(e)) from e raise RuntimeError(str(e)) from e
await download_model(task_id) await download_model(task_id, request)
return task_id return task_id
@@ -170,7 +164,7 @@ async def delete_model_download_task(task_id: str):
task_status = get_task_status(task_id) task_status = get_task_status(task_id)
is_running = task_status.status == "doing" is_running = task_status.status == "doing"
task_status.status = "waiting" task_status.status = "waiting"
await socket.send_json("deleteDownloadTask", task_id) await utils.send_json("delete_download_task", task_id)
# Pause the task # Pause the task
if is_running: if is_running:
@@ -183,15 +177,15 @@ async def delete_model_download_task(task_id: str):
task_file_target = os.path.splitext(task_file)[0] task_file_target = os.path.splitext(task_file)[0]
if task_file_target == task_id: if task_file_target == task_id:
delete_task_status(task_id) delete_task_status(task_id)
os.remove(os.path.join(download_dir, task_file)) os.remove(utils.join_path(download_dir, task_file))
await socket.send_json("deleteDownloadTask", task_id) await utils.send_json("delete_download_task", task_id)
async def download_model(task_id: str): async def download_model(task_id: str, request):
async def download_task(task_id: str): async def download_task(task_id: str):
async def report_progress(task_status: TaskStatus): async def report_progress(task_status: TaskStatus):
await socket.send_json("updateDownloadTask", task_status) await utils.send_json("update_download_task", task_status)
try: try:
# When starting a task from the queue, the task may not exist # When starting a task from the queue, the task may not exist
@@ -201,7 +195,7 @@ async def download_model(task_id: str):
# Update task status # Update task status
task_status.status = "doing" task_status.status = "doing"
await socket.send_json("updateDownloadTask", task_status) await utils.send_json("update_download_task", task_status)
try: try:
@@ -210,12 +204,12 @@ async def download_model(task_id: str):
download_platform = task_status.platform download_platform = task_status.platform
if download_platform == "civitai": if download_platform == "civitai":
api_key = utils.get_setting_value("api_key.civitai") api_key = utils.get_setting_value(request, "api_key.civitai")
if api_key: if api_key:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
elif download_platform == "huggingface": elif download_platform == "huggingface":
api_key = utils.get_setting_value("api_key.huggingface") api_key = utils.get_setting_value(request, "api_key.huggingface")
if api_key: if api_key:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
@@ -229,7 +223,7 @@ async def download_model(task_id: str):
except Exception as e: except Exception as e:
task_status.status = "pause" task_status.status = "pause"
task_status.error = str(e) task_status.error = str(e)
await socket.send_json("updateDownloadTask", task_status) await utils.send_json("update_download_task", task_status)
task_status.error = None task_status.error = None
logging.error(str(e)) logging.error(str(e))
@@ -238,11 +232,11 @@ async def download_model(task_id: str):
if status == "Waiting": if status == "Waiting":
task_status = get_task_status(task_id) task_status = get_task_status(task_id)
task_status.status = "waiting" task_status.status = "waiting"
await socket.send_json("updateDownloadTask", task_status) await utils.send_json("update_download_task", task_status)
except Exception as e: except Exception as e:
task_status.status = "pause" task_status.status = "pause"
task_status.error = str(e) task_status.error = str(e)
await socket.send_json("updateDownloadTask", task_status) await utils.send_json("update_download_task", task_status)
task_status.error = None task_status.error = None
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
@@ -264,8 +258,8 @@ async def download_model_file(
fullname = task_content.fullname fullname = task_content.fullname
# Write description file # Write description file
description = task_content.description description = task_content.description
description_file = os.path.join(download_path, f"{task_id}.md") description_file = utils.join_path(download_path, f"{task_id}.md")
with open(description_file, "w") as f: with open(description_file, "w", encoding="utf-8", newline="") as f:
f.write(description) f.write(description)
model_path = utils.get_full_path(model_type, path_index, fullname) model_path = utils.get_full_path(model_type, path_index, fullname)
@@ -273,9 +267,9 @@ async def download_model_file(
utils.rename_model(download_tmp_file, model_path) utils.rename_model(download_tmp_file, model_path)
time.sleep(1) time.sleep(1)
task_file = os.path.join(download_path, f"{task_id}.task") task_file = utils.join_path(download_path, f"{task_id}.task")
os.remove(task_file) os.remove(task_file)
await socket.send_json("completeDownloadTask", task_id) await utils.send_json("complete_download_task", task_id)
async def update_progress(): async def update_progress():
nonlocal last_update_time nonlocal last_update_time
@@ -297,7 +291,7 @@ async def download_model_file(
raise RuntimeError("No downloadUrl found") raise RuntimeError("No downloadUrl found")
download_path = utils.get_download_path() download_path = utils.get_download_path()
download_tmp_file = os.path.join(download_path, f"{task_id}.download") download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
downloaded_size = 0 downloaded_size = 0
if os.path.isfile(download_tmp_file): if os.path.isfile(download_tmp_file):
@@ -329,6 +323,13 @@ async def download_model_file(
# If no token is carried, it will be redirected to the login page. # If no token is carried, it will be redirected to the login page.
content_type = response.headers.get("content-type") content_type = response.headers.get("content-type")
if content_type and content_type.startswith("text/html"): if content_type and content_type.startswith("text/html"):
# TODO: More checks
# In addition to requiring login to download, there may be other restrictions.
# The only situation currently known is early access (see issue #43).
# Due to the lack of test data, let's put it aside for now.
# If it cannot be downloaded, a redirect will definitely occur.
# Maybe consider getting the redirect URL from response.history to make a judgment.
# Here we also need to consider how different websites are processed.
raise RuntimeError( raise RuntimeError(
f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first." f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
) )
@@ -340,7 +341,7 @@ async def download_model_file(
task_content.sizeBytes = total_size task_content.sizeBytes = total_size
task_status.totalSize = total_size task_status.totalSize = total_size
set_task_content(task_id, task_content) set_task_content(task_id, task_content)
await socket.send_json("updateDownloadTask", task_content) await utils.send_json("update_download_task", task_content)
with open(download_tmp_file, "ab") as f: with open(download_tmp_file, "ab") as f:
for chunk in response.iter_content(chunk_size=8192): for chunk in response.iter_content(chunk_size=8192):
@@ -359,4 +360,4 @@ async def download_model_file(
await download_complete() await download_complete()
else: else:
task_status.status = "pause" task_status.status = "pause"
await socket.send_json("updateDownloadTask", task_status) await utils.send_json("update_download_task", task_status)

View File

@@ -1,37 +1,13 @@
import os import os
import logging
import traceback
import folder_paths import folder_paths
from typing import Any
from multidict import MultiDictProxy from multidict import MultiDictProxy
from . import config from . import config
from . import utils from . import utils
from . import socket
from . import download from . import download
async def connect_websocket(request):
async def message_handler(event_type: str, detail: Any, sid: str):
try:
if event_type == "downloadTaskList":
await download.scan_model_download_task_list(sid=sid)
if event_type == "resumeDownloadTask":
await download.download_model(task_id=detail)
if event_type == "pauseDownloadTask":
await download.pause_model_download_task(task_id=detail)
if event_type == "deleteDownloadTask":
await download.delete_model_download_task(task_id=detail)
except Exception:
logging.error(traceback.format_exc())
ws = await socket.create_websocket_handler(request, handler=message_handler)
return ws
def scan_models(): def scan_models():
result = [] result = []
model_base_paths = config.model_base_paths model_base_paths = config.model_base_paths
@@ -46,16 +22,16 @@ def scan_models():
image_dict = utils.file_list_to_name_dict(images) image_dict = utils.file_list_to_name_dict(images)
for fullname in models: for fullname in models:
fullname = fullname.replace(os.path.sep, "/") fullname = utils.normalize_path(fullname)
basename = os.path.splitext(fullname)[0] basename = os.path.splitext(fullname)[0]
extension = os.path.splitext(fullname)[1] extension = os.path.splitext(fullname)[1]
abs_path = os.path.join(base_path, fullname) abs_path = utils.join_path(base_path, fullname)
file_stats = os.stat(abs_path) file_stats = os.stat(abs_path)
# Resolve preview # Resolve preview
image_name = image_dict.get(basename, "no-preview.png") image_name = image_dict.get(basename, "no-preview.png")
abs_image_path = os.path.join(base_path, image_name) abs_image_path = utils.join_path(base_path, image_name)
if os.path.isfile(abs_image_path): if os.path.isfile(abs_image_path):
image_state = os.stat(abs_image_path) image_state = os.stat(abs_image_path)
image_timestamp = round(image_state.st_mtime_ns / 1000000) image_timestamp = round(image_state.st_mtime_ns / 1000000)
@@ -87,10 +63,10 @@ def get_model_info(model_path: str):
metadata = utils.get_model_metadata(model_path) metadata = utils.get_model_metadata(model_path)
description_file = utils.get_model_description_name(model_path) description_file = utils.get_model_description_name(model_path)
description_file = os.path.join(directory, description_file) description_file = utils.join_path(directory, description_file)
description = None description = None
if os.path.isfile(description_file): if os.path.isfile(description_file):
with open(description_file, "r", encoding="utf-8") as f: with open(description_file, "r", encoding="utf-8", newline="") as f:
description = f.read() description = f.read()
return { return {
@@ -128,13 +104,29 @@ def remove_model(model_path: str):
model_previews = utils.get_model_all_images(model_path) model_previews = utils.get_model_all_images(model_path)
for preview in model_previews: for preview in model_previews:
os.remove(os.path.join(model_dirname, preview)) os.remove(utils.join_path(model_dirname, preview))
model_descriptions = utils.get_model_all_descriptions(model_path) model_descriptions = utils.get_model_all_descriptions(model_path)
for description in model_descriptions: for description in model_descriptions:
os.remove(os.path.join(model_dirname, description)) os.remove(utils.join_path(model_dirname, description))
async def create_model_download_task(post): async def create_model_download_task(post, request):
dict_post = dict(post) dict_post = dict(post)
return await download.create_model_download_task(dict_post) return await download.create_model_download_task(dict_post, request)
async def scan_model_download_task_list():
return await download.scan_model_download_task_list()
async def pause_model_download_task(task_id):
return await download.pause_model_download_task(task_id)
async def resume_model_download_task(task_id, request):
return await download.download_model(task_id, request)
async def delete_model_download_task(task_id):
return await download.delete_model_download_task(task_id)

View File

@@ -1,63 +0,0 @@
import aiohttp
import logging
import uuid
import json
from aiohttp import web
from typing import Any, Callable, Awaitable
from . import utils
__sockets: dict[str, web.WebSocketResponse] = {}
async def create_websocket_handler(
request, handler: Callable[[str, Any, str], Awaitable[Any]]
):
ws = web.WebSocketResponse()
await ws.prepare(request)
sid = request.rel_url.query.get("clientId", "")
if sid:
# Reusing existing session, remove old
__sockets.pop(sid, None)
else:
sid = uuid.uuid4().hex
__sockets[sid] = ws
try:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
logging.warning(
"ws connection closed with exception %s" % ws.exception()
)
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
await handler(data.get("type"), data.get("detail"), sid)
finally:
__sockets.pop(sid, None)
return ws
async def send_json(event: str, data: Any, sid: str = None):
detail = utils.unpack_dataclass(data)
message = {"type": event, "data": detail}
if sid is None:
socket_list = list(__sockets.values())
for ws in socket_list:
await __send_socket_catch_exception(ws.send_json, message)
elif sid in __sockets:
await __send_socket_catch_exception(__sockets[sid].send_json, message)
async def __send_socket_catch_exception(function, message):
try:
await function(message)
except (
aiohttp.ClientError,
aiohttp.ClientPayloadError,
ConnectionResetError,
BrokenPipeError,
ConnectionError,
) as err:
logging.warning("send error: {}".format(err))

View File

@@ -2,7 +2,6 @@ import asyncio
import threading import threading
import queue import queue
import logging import logging
from . import utils
class DownloadThreadPool: class DownloadThreadPool:
@@ -13,14 +12,7 @@ class DownloadThreadPool:
self._lock = threading.Lock() self._lock = threading.Lock()
default_max_workers = 5 default_max_workers = 5
max_workers: int = utils.get_setting_value( max_workers: int = default_max_workers
"download.max_task_count", default_max_workers
)
if max_workers <= 0:
max_workers = default_max_workers
utils.set_setting_value("download.max_task_count", max_workers)
self.max_worker = max_workers self.max_worker = max_workers
def submit(self, task, task_id): def submit(self, task, task_id):

View File

@@ -15,9 +15,18 @@ from typing import Any
from . import config from . import config
def normalize_path(path: str):
normpath = os.path.normpath(path)
return normpath.replace(os.path.sep, "/")
def join_path(path: str, *paths: list[str]):
return normalize_path(os.path.join(path, *paths))
def get_current_version(): def get_current_version():
try: try:
pyproject_path = os.path.join(config.extension_uri, "pyproject.toml") pyproject_path = join_path(config.extension_uri, "pyproject.toml")
config_parser = configparser.ConfigParser() config_parser = configparser.ConfigParser()
config_parser.read(pyproject_path) config_parser.read(pyproject_path)
version = config_parser.get("project", "version") version = config_parser.get("project", "version")
@@ -27,15 +36,15 @@ def get_current_version():
def download_web_distribution(version: str): def download_web_distribution(version: str):
web_path = os.path.join(config.extension_uri, "web") web_path = join_path(config.extension_uri, "web")
dev_web_file = os.path.join(web_path, "manager-dev.js") dev_web_file = join_path(web_path, "manager-dev.js")
if os.path.exists(dev_web_file): if os.path.exists(dev_web_file):
return return
web_version = "0.0.0" web_version = "0.0.0"
version_file = os.path.join(web_path, "version.yaml") version_file = join_path(web_path, "version.yaml")
if os.path.exists(version_file): if os.path.exists(version_file):
with open(version_file, "r") as f: with open(version_file, "r", encoding="utf-8", newline="") as f:
version_content = yaml.safe_load(f) version_content = yaml.safe_load(f)
web_version = version_content.get("version", web_version) web_version = version_content.get("version", web_version)
@@ -49,7 +58,7 @@ def download_web_distribution(version: str):
response = requests.get(download_url, stream=True) response = requests.get(download_url, stream=True)
response.raise_for_status() response.raise_for_status()
temp_file = os.path.join(config.extension_uri, "temp.tar.gz") temp_file = join_path(config.extension_uri, "temp.tar.gz")
with open(temp_file, "wb") as f: with open(temp_file, "wb") as f:
for chunk in response.iter_content(chunk_size=8192): for chunk in response.iter_content(chunk_size=8192):
f.write(chunk) f.write(chunk)
@@ -82,7 +91,8 @@ def resolve_model_base_paths():
continue continue
if folder == "custom_nodes": if folder == "custom_nodes":
continue continue
config.model_base_paths[folder] = folder_paths.get_folder_paths(folder) folders = folder_paths.get_folder_paths(folder)
config.model_base_paths[folder] = [normalize_path(f) for f in folders]
def get_full_path(model_type: str, path_index: int, filename: str): def get_full_path(model_type: str, path_index: int, filename: str):
@@ -93,7 +103,8 @@ def get_full_path(model_type: str, path_index: int, filename: str):
if not path_index < len(folders): if not path_index < len(folders):
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}") raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
base_path = folders[path_index] base_path = folders[path_index]
return os.path.join(base_path, filename) full_path = join_path(base_path, filename)
return full_path
def get_valid_full_path(model_type: str, path_index: int, filename: str): def get_valid_full_path(model_type: str, path_index: int, filename: str):
@@ -104,7 +115,7 @@ def get_valid_full_path(model_type: str, path_index: int, filename: str):
if not path_index < len(folders): if not path_index < len(folders):
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}") raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
base_path = folders[path_index] base_path = folders[path_index]
full_path = os.path.join(base_path, filename) full_path = join_path(base_path, filename)
if os.path.isfile(full_path): if os.path.isfile(full_path):
return full_path return full_path
elif os.path.islink(full_path): elif os.path.islink(full_path):
@@ -114,7 +125,7 @@ def get_valid_full_path(model_type: str, path_index: int, filename: str):
def get_download_path(): def get_download_path():
download_path = os.path.join(config.extension_uri, "downloads") download_path = join_path(config.extension_uri, "downloads")
if not os.path.exists(download_path): if not os.path.exists(download_path):
os.makedirs(download_path) os.makedirs(download_path)
return download_path return download_path
@@ -124,12 +135,12 @@ def recursive_search_files(directory: str):
files, folder_all = folder_paths.recursive_search( files, folder_all = folder_paths.recursive_search(
directory, excluded_dir_names=[".git"] directory, excluded_dir_names=[".git"]
) )
return files return [normalize_path(f) for f in files]
def search_files(directory: str): def search_files(directory: str):
entries = os.listdir(directory) entries = os.listdir(directory)
files = [f for f in entries if os.path.isfile(os.path.join(directory, f))] files = [f for f in entries if os.path.isfile(join_path(directory, f))]
return files return files
@@ -193,13 +204,13 @@ def save_model_preview_image(model_path: str, image_file: Any):
for image in old_preview_images: for image in old_preview_images:
if os.path.splitext(image)[1].endswith(".preview"): if os.path.splitext(image)[1].endswith(".preview"):
a1111_civitai_helper_image = True a1111_civitai_helper_image = True
image_path = os.path.join(base_dirname, image) image_path = join_path(base_dirname, image)
os.remove(image_path) os.remove(image_path)
# save new preview image # save new preview image
basename = os.path.splitext(os.path.basename(model_path))[0] basename = os.path.splitext(os.path.basename(model_path))[0]
extension = f".{content_type.split('/')[1]}" extension = f".{content_type.split('/')[1]}"
new_preview_path = os.path.join(base_dirname, f"{basename}{extension}") new_preview_path = join_path(base_dirname, f"{basename}{extension}")
with open(new_preview_path, "wb") as f: with open(new_preview_path, "wb") as f:
f.write(image_file.file.read()) f.write(image_file.file.read())
@@ -209,7 +220,7 @@ def save_model_preview_image(model_path: str, image_file: Any):
""" """
Keep preview image of a1111_civitai_helper Keep preview image of a1111_civitai_helper
""" """
new_preview_path = os.path.join(base_dirname, f"{basename}.preview{extension}") new_preview_path = join_path(base_dirname, f"{basename}.preview{extension}")
with open(new_preview_path, "wb") as f: with open(new_preview_path, "wb") as f:
f.write(image_file.file.read()) f.write(image_file.file.read())
@@ -243,15 +254,15 @@ def save_model_description(model_path: str, content: Any):
# remove old descriptions # remove old descriptions
old_descriptions = get_model_all_descriptions(model_path) old_descriptions = get_model_all_descriptions(model_path)
for desc in old_descriptions: for desc in old_descriptions:
description_path = os.path.join(base_dirname, desc) description_path = join_path(base_dirname, desc)
os.remove(description_path) os.remove(description_path)
# save new description # save new description
basename = os.path.splitext(os.path.basename(model_path))[0] basename = os.path.splitext(os.path.basename(model_path))[0]
extension = ".md" extension = ".md"
new_desc_path = os.path.join(base_dirname, f"{basename}{extension}") new_desc_path = join_path(base_dirname, f"{basename}{extension}")
with open(new_desc_path, "w", encoding="utf-8") as f: with open(new_desc_path, "w", encoding="utf-8", newline="") as f:
f.write(content) f.write(content)
@@ -272,29 +283,27 @@ def rename_model(model_path: str, new_model_path: str):
os.makedirs(new_model_dirname) os.makedirs(new_model_dirname)
# move model # move model
os.rename(model_path, new_model_path) shutil.move(model_path, new_model_path)
# move preview # move preview
previews = get_model_all_images(model_path) previews = get_model_all_images(model_path)
for preview in previews: for preview in previews:
preview_path = os.path.join(model_dirname, preview) preview_path = join_path(model_dirname, preview)
preview_name = os.path.splitext(preview)[0] preview_name = os.path.splitext(preview)[0]
preview_ext = os.path.splitext(preview)[1] preview_ext = os.path.splitext(preview)[1]
new_preview_path = ( new_preview_path = (
os.path.join(new_model_dirname, new_model_name + preview_ext) join_path(new_model_dirname, new_model_name + preview_ext)
if preview_name == model_name if preview_name == model_name
else os.path.join( else join_path(new_model_dirname, new_model_name + ".preview" + preview_ext)
new_model_dirname, new_model_name + ".preview" + preview_ext
)
) )
os.rename(preview_path, new_preview_path) shutil.move(preview_path, new_preview_path)
# move description # move description
description = get_model_description_name(model_path) description = get_model_description_name(model_path)
description_path = os.path.join(model_dirname, description) description_path = join_path(model_dirname, description)
if os.path.isfile(description_path): if os.path.isfile(description_path):
new_description_path = os.path.join(new_model_dirname, f"{new_model_name}.md") new_description_path = join_path(new_model_dirname, f"{new_model_name}.md")
os.rename(description_path, new_description_path) shutil.move(description_path, new_description_path)
import pickle import pickle
@@ -325,18 +334,16 @@ def resolve_setting_key(key: str) -> str:
return setting_id return setting_id
def set_setting_value(key: str, value: Any): def set_setting_value(request: web.Request, key: str, value: Any):
setting_id = resolve_setting_key(key) setting_id = resolve_setting_key(key)
fake_request = config.FakeRequest() settings = config.serverInstance.user_manager.settings.get_settings(request)
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
settings[setting_id] = value settings[setting_id] = value
config.serverInstance.user_manager.settings.save_settings(fake_request, settings) config.serverInstance.user_manager.settings.save_settings(request, settings)
def get_setting_value(key: str, default: Any = None) -> Any: def get_setting_value(request: web.Request, key: str, default: Any = None) -> Any:
setting_id = resolve_setting_key(key) setting_id = resolve_setting_key(key)
fake_request = config.FakeRequest() settings = config.serverInstance.user_manager.settings.get_settings(request)
settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
return settings.get(setting_id, default) return settings.get(setting_id, default)
@@ -352,3 +359,8 @@ def unpack_dataclass(data: Any):
return asdict(data) return asdict(data)
else: else:
return data return data
async def send_json(event: str, data: Any, sid: str = None):
detail = unpack_dataclass(data)
await config.serverInstance.send_json(event, detail, sid)

View File

@@ -1,7 +1,7 @@
[project] [project]
name = "comfyui-model-manager" name = "comfyui-model-manager"
description = "Manage models: browsing, download and delete." description = "Manage models: browsing, download and delete."
version = "2.0.0" version = "2.0.3"
license = "LICENSE" license = "LICENSE"
[project.urls] [project.urls]

View File

@@ -28,21 +28,23 @@
<ResponseScroll class="-mx-5 h-full"> <ResponseScroll class="-mx-5 h-full">
<div class="px-5"> <div class="px-5">
<ModelContent <KeepAlive>
v-if="currentModel" <ModelContent
:key="currentModel.id" v-if="currentModel"
:model="currentModel" :key="currentModel.id"
:editable="true" :model="currentModel"
@submit="createDownTask" :editable="true"
> @submit="createDownTask"
<template #action> >
<Button <template #action>
icon="pi pi-download" <Button
:label="$t('download')" icon="pi pi-download"
type="submit" :label="$t('download')"
></Button> type="submit"
</template> ></Button>
</ModelContent> </template>
</ModelContent>
</KeepAlive>
<div v-show="data.length === 0"> <div v-show="data.length === 0">
<div class="flex flex-col items-center gap-4 py-8"> <div class="flex flex-col items-center gap-4 py-8">

View File

@@ -40,6 +40,7 @@
</div> </div>
<ResponseScroll <ResponseScroll
ref="responseScroll"
:items="list" :items="list"
:itemSize="itemSize" :itemSize="itemSize"
:row-key="(item) => item.map(genModelKey).join(',')" :row-key="(item) => item.map(genModelKey).join(',')"
@@ -80,7 +81,7 @@ import ModelCard from 'components/ModelCard.vue'
import ResponseInput from 'components/ResponseInput.vue' import ResponseInput from 'components/ResponseInput.vue'
import ResponseSelect from 'components/ResponseSelect.vue' import ResponseSelect from 'components/ResponseSelect.vue'
import ResponseScroll from 'components/ResponseScroll.vue' import ResponseScroll from 'components/ResponseScroll.vue'
import { computed, ref } from 'vue' import { computed, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { chunk } from 'lodash' import { chunk } from 'lodash'
import { defineResizeCallback } from 'hooks/resize' import { defineResizeCallback } from 'hooks/resize'
@@ -91,6 +92,8 @@ const { isMobile, cardWidth, gutter, aspect, modelFolders } = useConfig()
const { data } = useModels() const { data } = useModels()
const { t } = useI18n() const { t } = useI18n()
const responseScroll = ref()
const searchContent = ref<string>() const searchContent = ref<string>()
const currentType = ref('all') const currentType = ref('all')
@@ -120,6 +123,10 @@ const sortOrderOptions = ref(
}), }),
) )
watch([searchContent, currentType], () => {
responseScroll.value.init()
})
const itemSize = computed(() => { const itemSize = computed(() => {
let itemWidth = cardWidth let itemWidth = cardWidth
let itemGutter = gutter let itemGutter = gutter

View File

@@ -72,8 +72,8 @@ const handleCancel = () => {
} }
const handleSave = async (data: BaseModel) => { const handleSave = async (data: BaseModel) => {
await update(modelContent.value, data)
editable.value = false editable.value = false
await update(props.model, data)
} }
const handleDelete = async () => { const handleDelete = async () => {

View File

@@ -29,11 +29,13 @@
<col /> <col />
</colgroup> </colgroup>
<tbody> <tbody>
<tr v-for="item in information" class="h-8 border-b"> <tr v-for="item in information" class="h-8 whitespace-nowrap border-b">
<td class="border-r bg-gray-300 px-4 dark:bg-gray-800"> <td class="border-r bg-gray-300 px-4 dark:bg-gray-800">
{{ $t(`info.${item.key}`) }} {{ $t(`info.${item.key}`) }}
</td> </td>
<td class="break-all px-4">{{ item.display }}</td> <td class="overflow-hidden text-ellipsis break-all px-4">
{{ item.display }}
</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
@@ -81,7 +83,8 @@ const pathOptions = computed(() => {
const information = computed(() => { const information = computed(() => {
return Object.values(baseInfo.value).filter((row) => { return Object.values(baseInfo.value).filter((row) => {
if (editable.value) { if (editable.value) {
return row.key !== 'fullname' const hiddenKeys = ['fullname', 'pathIndex']
return !hiddenKeys.includes(row.key)
} }
return true return true
}) })

View File

@@ -34,7 +34,7 @@
</div> </div>
</div> </div>
<div class="duration-300 group-hover/card:opacity-100"> <div class="opacity-0 duration-300 group-hover/card:opacity-100">
<div class="flex flex-col gap-4 *:pointer-events-auto"> <div class="flex flex-col gap-4 *:pointer-events-auto">
<Button <Button
icon="pi pi-plus" icon="pi pi-plus"

View File

@@ -16,7 +16,7 @@
></textarea> ></textarea>
<div v-show="!active"> <div v-show="!active">
<div v-show="editable" class="flex items-center gap-2 text-gray-600"> <div v-show="editable" class="mb-4 flex items-center gap-2 text-gray-600">
<i class="pi pi-info-circle"></i> <i class="pi pi-info-circle"></i>
<span> <span>
{{ $t('tapToChange') }} {{ $t('tapToChange') }}

View File

@@ -44,7 +44,7 @@ const [content, modifiers] = defineModel<string, 'trim'>()
const inputRef = ref() const inputRef = ref()
const innerValue = ref(content) const innerValue = ref(content)
const trigger = computed(() => props.updateTrigger ?? 'input') const trigger = computed(() => props.updateTrigger ?? 'change')
const updateContent = () => { const updateContent = () => {
let value = innerValue.value let value = innerValue.value

View File

@@ -298,7 +298,6 @@ const startDragThumb = (event: MouseEvent) => {
watch( watch(
() => props.items, () => props.items,
() => { () => {
init()
setSpacerSize() setSpacerSize()
calculateScrollThumbSize() calculateScrollThumbSize()
calculateLoadItems() calculateLoadItems()
@@ -311,5 +310,6 @@ onUnmounted(() => {
defineExpose({ defineExpose({
viewport, viewport,
init,
}) })
</script> </script>

View File

@@ -1,8 +1,9 @@
import { useLoading } from 'hooks/loading' import { useLoading } from 'hooks/loading'
import { MarkdownTool, useMarkdown } from 'hooks/markdown' import { MarkdownTool, useMarkdown } from 'hooks/markdown'
import { socket } from 'hooks/socket' import { request } from 'hooks/request'
import { defineStore } from 'hooks/store' import { defineStore } from 'hooks/store'
import { useToast } from 'hooks/toast' import { useToast } from 'hooks/toast'
import { api } from 'scripts/comfyAPI'
import { bytesToSize } from 'utils/common' import { bytesToSize } from 'utils/common'
import { onBeforeMount, onMounted, ref, watch } from 'vue' import { onBeforeMount, onMounted, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
@@ -13,10 +14,6 @@ export const useDownload = defineStore('download', (store) => {
const taskList = ref<DownloadTask[]>([]) const taskList = ref<DownloadTask[]>([])
const refresh = () => {
socket.send('downloadTaskList', null)
}
const createTaskItem = (item: DownloadTaskOptions) => { const createTaskItem = (item: DownloadTaskOptions) => {
const { downloadedSize, totalSize, bps, ...rest } = item const { downloadedSize, totalSize, bps, ...rest } = item
@@ -26,10 +23,20 @@ export const useDownload = defineStore('download', (store) => {
downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`, downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`,
downloadSpeed: `${bytesToSize(bps)}/s`, downloadSpeed: `${bytesToSize(bps)}/s`,
pauseTask() { pauseTask() {
socket.send('pauseDownloadTask', item.taskId) request(`/download/${item.taskId}`, {
method: 'PUT',
body: JSON.stringify({
status: 'pause',
}),
})
}, },
resumeTask: () => { resumeTask: () => {
socket.send('resumeDownloadTask', item.taskId) request(`/download/${item.taskId}`, {
method: 'PUT',
body: JSON.stringify({
status: 'resume',
}),
})
}, },
deleteTask: () => { deleteTask: () => {
confirm.require({ confirm.require({
@@ -46,7 +53,9 @@ export const useDownload = defineStore('download', (store) => {
severity: 'danger', severity: 'danger',
}, },
accept: () => { accept: () => {
socket.send('deleteDownloadTask', item.taskId) request(`/download/${item.taskId}`, {
method: 'DELETE',
})
}, },
reject: () => {}, reject: () => {},
}) })
@@ -56,12 +65,28 @@ export const useDownload = defineStore('download', (store) => {
return task return task
} }
const refresh = async () => {
return request('/download/task')
.then((resData: DownloadTaskOptions[]) => {
taskList.value = resData.map((item) => createTaskItem(item))
return taskList.value
})
.catch((err) => {
toast.add({
severity: 'error',
summary: 'Error',
detail: err.message ?? 'Failed to refresh download task list',
life: 15000,
})
})
}
onBeforeMount(() => { onBeforeMount(() => {
socket.addEventListener('reconnected', () => { api.addEventListener('reconnected', () => {
refresh() refresh()
}) })
socket.addEventListener('downloadTaskList', (event) => { api.addEventListener('fetch_download_task_list', (event) => {
const data = event.detail as DownloadTaskOptions[] const data = event.detail as DownloadTaskOptions[]
taskList.value = data.map((item) => { taskList.value = data.map((item) => {
@@ -69,12 +94,12 @@ export const useDownload = defineStore('download', (store) => {
}) })
}) })
socket.addEventListener('createDownloadTask', (event) => { api.addEventListener('create_download_task', (event) => {
const item = event.detail as DownloadTaskOptions const item = event.detail as DownloadTaskOptions
taskList.value.unshift(createTaskItem(item)) taskList.value.unshift(createTaskItem(item))
}) })
socket.addEventListener('updateDownloadTask', (event) => { api.addEventListener('update_download_task', (event) => {
const item = event.detail as DownloadTaskOptions const item = event.detail as DownloadTaskOptions
for (const task of taskList.value) { for (const task of taskList.value) {
@@ -93,12 +118,12 @@ export const useDownload = defineStore('download', (store) => {
} }
}) })
socket.addEventListener('deleteDownloadTask', (event) => { api.addEventListener('delete_download_task', (event) => {
const taskId = event.detail as string const taskId = event.detail as string
taskList.value = taskList.value.filter((item) => item.taskId !== taskId) taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
}) })
socket.addEventListener('completeDownloadTask', (event) => { api.addEventListener('complete_download_task', (event) => {
const taskId = event.detail as string const taskId = event.detail as string
const task = taskList.value.find((item) => item.taskId === taskId) const task = taskList.value.find((item) => item.taskId === taskId)
taskList.value = taskList.value.filter((item) => item.taskId !== taskId) taskList.value = taskList.value.filter((item) => item.taskId !== taskId)

View File

@@ -1,3 +1,4 @@
import { useConfig } from 'hooks/config'
import { useLoading } from 'hooks/loading' import { useLoading } from 'hooks/loading'
import { useMarkdown } from 'hooks/markdown' import { useMarkdown } from 'hooks/markdown'
import { request, useRequest } from 'hooks/request' import { request, useRequest } from 'hooks/request'
@@ -7,7 +8,7 @@ import { cloneDeep } from 'lodash'
import { app } from 'scripts/comfyAPI' import { app } from 'scripts/comfyAPI'
import { bytesToSize, formatDate, previewUrlToFile } from 'utils/common' import { bytesToSize, formatDate, previewUrlToFile } from 'utils/common'
import { ModelGrid } from 'utils/legacy' import { ModelGrid } from 'utils/legacy'
import { resolveModelTypeLoader } from 'utils/model' import { genModelKey, resolveModelTypeLoader } from 'utils/model'
import { import {
computed, computed,
inject, inject,
@@ -20,7 +21,7 @@ import {
} from 'vue' } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
export const useModels = defineStore('models', () => { export const useModels = defineStore('models', (store) => {
const { data, refresh } = useRequest<Model[]>('/models', { defaultValue: [] }) const { data, refresh } = useRequest<Model[]>('/models', { defaultValue: [] })
const { toast, confirm } = useToast() const { toast, confirm } = useToast()
const { t } = useI18n() const { t } = useI18n()
@@ -28,6 +29,7 @@ export const useModels = defineStore('models', () => {
const updateModel = async (model: BaseModel, data: BaseModel) => { const updateModel = async (model: BaseModel, data: BaseModel) => {
const formData = new FormData() const formData = new FormData()
let oldKey: string | null = null
// Check current preview // Check current preview
if (model.preview !== data.preview) { if (model.preview !== data.preview) {
@@ -45,6 +47,7 @@ export const useModels = defineStore('models', () => {
model.fullname !== data.fullname || model.fullname !== data.fullname ||
model.pathIndex !== data.pathIndex model.pathIndex !== data.pathIndex
) { ) {
oldKey = genModelKey(model)
formData.append('type', data.type) formData.append('type', data.type)
formData.append('pathIndex', data.pathIndex.toString()) formData.append('pathIndex', data.pathIndex.toString())
formData.append('fullname', data.fullname) formData.append('fullname', data.fullname)
@@ -59,19 +62,25 @@ export const useModels = defineStore('models', () => {
method: 'PUT', method: 'PUT',
body: formData, body: formData,
}) })
.catch(() => { .catch((err) => {
const error_message = err.message ?? err.error
toast.add({ toast.add({
severity: 'error', severity: 'error',
summary: 'Error', summary: 'Error',
detail: 'Failed to update model', detail: `Failed to update model: ${error_message}`,
life: 15000, life: 15000,
}) })
throw new Error(error_message)
}) })
.finally(() => { .finally(() => {
loading.hide() loading.hide()
}) })
await refresh() if (oldKey) {
store.dialog.close({ key: oldKey })
}
refresh()
} }
const deleteModel = async (model: BaseModel) => { const deleteModel = async (model: BaseModel) => {
@@ -90,6 +99,7 @@ export const useModels = defineStore('models', () => {
severity: 'danger', severity: 'danger',
}, },
accept: () => { accept: () => {
const dialogKey = genModelKey(model)
loading.show() loading.show()
request(`/model/${model.type}/${model.pathIndex}/${model.fullname}`, { request(`/model/${model.type}/${model.pathIndex}/${model.fullname}`, {
method: 'DELETE', method: 'DELETE',
@@ -101,6 +111,7 @@ export const useModels = defineStore('models', () => {
detail: `${model.fullname} Deleted`, detail: `${model.fullname} Deleted`,
life: 2000, life: 2000,
}) })
store.dialog.close({ key: dialogKey })
return refresh() return refresh()
}) })
.then(() => { .then(() => {
@@ -118,7 +129,9 @@ export const useModels = defineStore('models', () => {
loading.hide() loading.hide()
}) })
}, },
reject: () => {}, reject: () => {
resolve(void 0)
},
}) })
}) })
} }
@@ -191,6 +204,8 @@ const baseInfoKey = Symbol('baseInfo') as InjectionKey<
export const useModelBaseInfoEditor = (formInstance: ModelFormInstance) => { export const useModelBaseInfoEditor = (formInstance: ModelFormInstance) => {
const { formData: model, modelData } = formInstance const { formData: model, modelData } = formInstance
const { modelFolders } = useConfig()
const type = computed({ const type = computed({
get: () => { get: () => {
return model.value.type return model.value.type
@@ -239,6 +254,15 @@ export const useModelBaseInfoEditor = (formInstance: ModelFormInstance) => {
key: 'type', key: 'type',
formatter: () => modelData.value.type, formatter: () => modelData.value.type,
}, },
{
key: 'pathIndex',
formatter: () => {
const modelType = modelData.value.type
const pathIndex = modelData.value.pathIndex
const folders = modelFolders.value[modelType] ?? []
return `${folders[pathIndex]}`
},
},
{ {
key: 'fullname', key: 'fullname',
formatter: (val) => val, formatter: (val) => val,

View File

@@ -1,82 +0,0 @@
import { globalToast } from 'hooks/toast'
import { readonly } from 'vue'
class WebSocketEvent extends EventTarget {
private socket: WebSocket | null
constructor() {
super()
this.createSocket()
}
private createSocket(isReconnect?: boolean) {
const api_host = location.host
const api_base = location.pathname.split('/').slice(0, -1).join('/')
let opened = false
let existingSession = window.name
if (existingSession) {
existingSession = '?clientId=' + existingSession
}
this.socket = readonly(
new WebSocket(
`ws${window.location.protocol === 'https:' ? 's' : ''}://${api_host}${api_base}/model-manager/ws${existingSession}`,
),
)
this.socket.addEventListener('open', () => {
opened = true
if (isReconnect) {
this.dispatchEvent(new CustomEvent('reconnected'))
}
})
this.socket.addEventListener('error', () => {
if (this.socket) this.socket.close()
})
this.socket.addEventListener('close', (event) => {
setTimeout(() => {
this.socket = null
this.createSocket(true)
}, 300)
if (opened) {
this.dispatchEvent(new CustomEvent('status', { detail: null }))
this.dispatchEvent(new CustomEvent('reconnecting'))
}
})
this.socket.addEventListener('message', (event) => {
try {
const msg = JSON.parse(event.data)
if (msg.type === 'error') {
globalToast.value?.add({
severity: 'error',
summary: 'Error',
detail: msg.data,
life: 15000,
})
} else {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }))
}
} catch (error) {
console.error(error)
}
})
}
addEventListener = (
type: string,
callback: CustomEventListener | null,
options?: AddEventListenerOptions | boolean,
) => {
super.addEventListener(type, callback, options)
}
send(type: string, data: any) {
this.socket?.send(JSON.stringify({ type, detail: data }))
}
}
export const socket = new WebSocketEvent()

View File

@@ -32,6 +32,7 @@ const messages = {
}, },
info: { info: {
type: 'Model Type', type: 'Model Type',
pathIndex: 'Directory',
fullname: 'File Name', fullname: 'File Name',
sizeBytes: 'File Size', sizeBytes: 'File Size',
createdAt: 'Created At', createdAt: 'Created At',
@@ -69,6 +70,7 @@ const messages = {
}, },
info: { info: {
type: '类型', type: '类型',
pathIndex: '目录',
fullname: '文件名', fullname: '文件名',
sizeBytes: '文件大小', sizeBytes: '文件大小',
createdAt: '创建时间', createdAt: '创建时间',

View File

@@ -82,26 +82,69 @@
.markdown-it { .markdown-it {
font-family: theme('fontFamily.sans'); font-family: theme('fontFamily.sans');
font-size: theme('fontSize.base');
line-height: theme('lineHeight.relaxed'); line-height: theme('lineHeight.relaxed');
word-break: break-word; word-break: break-word;
margin: 0; margin: 0;
&::before {
display: table;
content: '';
}
&::after {
display: table;
content: '';
clear: both;
}
> *:first-child {
margin-top: 0 !important;
}
> *:last-child {
margin-bottom: 0 !important;
}
h1,
h2,
h3,
h4,
h5,
h6 {
margin-top: 1.5em;
margin-bottom: 1em;
font-weight: 600;
line-height: 1.25;
}
h1 { h1 {
font-size: theme('fontSize.2xl'); font-size: 2em;
font-weight: theme('fontWeight.bold'); padding-bottom: 0.3em;
border-bottom: 1px solid #ddd; border-bottom: 1px solid var(--p-surface-700);
margin-top: theme('margin.4');
margin-bottom: theme('margin.4');
padding-bottom: theme('padding[2.5]');
} }
h2 { h2 {
font-size: theme('fontSize.xl'); font-size: 1.5em;
font-weight: theme('fontWeight.bold'); padding-bottom: 0.3em;
border-bottom: 1px solid var(--p-surface-700);
} }
h3 { h3 {
font-size: theme('fontSize.lg'); font-size: 1.25em;
}
h4 {
font-size: 1em;
}
h5 {
font-size: 0.875em;
}
h6 {
font-size: 0.85em;
color: var(--p-surface-500);
} }
a { a {
@@ -114,8 +157,16 @@
text-decoration: underline; text-decoration: underline;
} }
p { p,
margin: 1em 0; blockquote,
ul,
ol,
dl,
table,
pre,
details {
margin-top: 0;
margin-bottom: 1em;
} }
p img { p img {
@@ -126,7 +177,6 @@
ul, ul,
ol { ol {
margin: 1em 0;
padding-left: 2em; padding-left: 2em;
} }
@@ -135,23 +185,38 @@
} }
blockquote { blockquote {
border-left: 5px solid #ddd; padding: 0px 1em;
padding: 10px 20px; border-left: 0.25em solid var(--p-surface-500);
margin: 1.5em 0; color: var(--p-surface-500);
background: #f9f9f9; margin: 1em 0;
} }
code, blockquote > *:first-child {
pre { margin-top: 0;
background: #f9f9f9; }
padding: 3px 5px;
border: 1px solid #ddd; blockquote > *:last-child {
border-radius: 3px; margin-bottom: 0;
font-family: 'Courier New', Courier, monospace;
} }
pre { pre {
padding: 10px; font-size: 85%;
border-radius: 6px;
padding: 8px 16px;
overflow-x: auto; overflow-x: auto;
background: var(--p-dialog-background);
filter: invert(10%);
}
pre code,
pre tt {
display: inline;
padding: 0;
margin: 0;
overflow: visible;
line-height: inherit;
word-wrap: normal;
background-color: transparent;
border: 0;
} }
} }

View File

@@ -79,7 +79,7 @@ function dev(): Plugin {
fs.mkdirSync(outDirPath) fs.mkdirSync(outDirPath)
const port = server.config.server.port const port = server.config.server.port
const content = `import "http://127.0.0.1:${port}/src/main.ts";` const content = `import "http://localhost:${port}/src/main.ts";`
fs.writeFileSync(path.join(outDirPath, 'manager-dev.js'), content) fs.writeFileSync(path.join(outDirPath, 'manager-dev.js'), content)
}) })
}, },