Files
ComfyUI-Model-Manager/py/services.py

146 lines
5.2 KiB
Python

import os
import logging
import traceback
import folder_paths
from typing import Any
from multidict import MultiDictProxy
from . import utils
from . import socket
from . import download
async def connect_websocket(request):
    """
    Open the model-manager websocket for ``request`` and route its messages.

    Every incoming message is dispatched on its ``event_type`` to the matching
    ``download`` coroutine; event types not listed below are ignored.
    Exceptions raised while handling a message are logged and not propagated.

    Returns:
        Whatever ``socket.create_websocket_handler`` returns for ``request``.
    """

    async def message_handler(event_type: str, detail: Any, sid: str):
        # ``detail`` is forwarded as the task id for the per-task events;
        # ``sid`` is forwarded to the task-list scan (presumably the id of the
        # requesting client session — confirm against socket.py).
        try:
            if event_type == "downloadTaskList":
                await download.scan_model_download_task_list(sid=sid)
            elif event_type == "resumeDownloadTask":
                await download.download_model(task_id=detail)
            elif event_type == "pauseDownloadTask":
                await download.pause_model_download_task(task_id=detail)
            elif event_type == "deleteDownloadTask":
                await download.delete_model_download_task(task_id=detail)
        except Exception:
            # Log with traceback and keep going, so one failing message does
            # not escape into the websocket handler.
            logging.exception("Failed to handle websocket event %r", event_type)

    return await socket.create_websocket_handler(request, handler=message_handler)
def scan_models_by_model_type(model_type: str):
    """
    Scan all model folders registered for ``model_type`` and describe each model.

    Each search folder in ``folder_paths.folder_names_and_paths[model_type]``
    is walked with ``utils.recursive_search_files``. Files accepted by
    ``folder_paths.filter_files_extensions`` are reported.

    Args:
        model_type: Folder-name key of ``folder_paths.folder_names_and_paths``.

    Returns:
        A list with one dict per model file. Each dict has these keys:

        * ``fullname``, ``basename``, ``extension`` and ``type``.
        * ``pathIndex``: index of the search folder the file was found in.
        * ``sizeBytes``.
        * ``preview``: URL under ``/model-manager/preview/``.
        * ``description``: text of the description file, or None.
        * ``createdAt`` / ``updatedAt``: milliseconds, derived from
          ``st_ctime_ns`` / ``st_mtime_ns``.
        * ``metadata``.
    """
    out = []
    folders, extensions = folder_paths.folder_names_and_paths[model_type]
    for path_index, base_path in enumerate(folders):
        files = utils.recursive_search_files(base_path)
        models = folder_paths.filter_files_extensions(files, extensions)
        for fullname in models:
            # ``fullname`` is the model path relative to ``base_path``, e.g.
            #   abs_path    /path/to/models/stable-diffusion/custom_group/model_name.ckpt
            #   base_path   /path/to/models/stable-diffusion
            #   fullname    custom_group/model_name.ckpt
            #   basename    custom_group/model_name
            #   extension   .ckpt
            #   prefix_path custom_group/
            fullname = fullname.replace(os.path.sep, "/")
            basename, extension = os.path.splitext(fullname)
            # Directory part including the trailing "/" ("" for top-level
            # files). We slice at the last "/" rather than deleting the file
            # name with str.replace. replace would also remove any earlier
            # occurrence of that name inside the directory part.
            prefix_path = fullname[: fullname.rfind("/") + 1]
            abs_path = os.path.join(base_path, fullname)
            file_stats = os.stat(abs_path)

            # Resolve metadata
            metadata = utils.get_model_metadata(abs_path)

            # Resolve preview. When the image exists, its mtime is appended
            # as a query string, presumably so clients refetch it after it
            # changes.
            image_name = os.path.join(
                prefix_path, utils.get_model_preview_name(abs_path)
            )
            abs_image_path = os.path.join(base_path, image_name)
            if os.path.isfile(abs_image_path):
                image_state = os.stat(abs_image_path)
                image_timestamp = round(image_state.st_mtime_ns / 1000000)
                image_name = f"{image_name}?ts={image_timestamp}"
            model_preview = (
                f"/model-manager/preview/{model_type}/{path_index}/{image_name}"
            )

            # Resolve description (None when no description file exists)
            description_file = os.path.join(
                prefix_path, utils.get_model_description_name(abs_path)
            )
            abs_desc_path = os.path.join(base_path, description_file)
            description = None
            if os.path.isfile(abs_desc_path):
                with open(abs_desc_path, "r", encoding="utf-8") as f:
                    description = f.read()

            out.append(
                {
                    "fullname": fullname,
                    "basename": basename,
                    "extension": extension,
                    "type": model_type,
                    "pathIndex": path_index,
                    "sizeBytes": file_stats.st_size,
                    "preview": model_preview,
                    "description": description,
                    "createdAt": round(file_stats.st_ctime_ns / 1000000),
                    "updatedAt": round(file_stats.st_mtime_ns / 1000000),
                    "metadata": metadata,
                }
            )
    return out
def update_model(model_path: str, post: "MultiDictProxy"):
    """
    Apply the edits submitted in ``post`` to the model at ``model_path``.

    Recognized fields (all optional):

    * ``previewFile``: passed to ``utils.save_model_preview_image``.
    * ``description``: passed to ``utils.save_model_description``.
    * ``type`` + ``pathIndex`` + ``fullname``: when all three are present,
      the target path is built with ``utils.get_full_path`` and handed to
      ``utils.rename_model``.

    The rename fields are validated before any file is written, so an invalid
    request does not leave a partially applied update behind.

    Raises:
        RuntimeError: if the rename fields are present but ``type`` or
            ``fullname`` is empty, or ``pathIndex`` is not an integer.
    """
    rename_args = None
    if "type" in post and "pathIndex" in post and "fullname" in post:
        model_type = post["type"]
        fullname = post["fullname"]
        try:
            path_index = int(post["pathIndex"])
        except (TypeError, ValueError) as err:
            raise RuntimeError("Invalid type or pathIndex or fullname") from err
        if not model_type or not fullname:
            raise RuntimeError("Invalid type or pathIndex or fullname")
        rename_args = (model_type, path_index, fullname)

    if "previewFile" in post:
        utils.save_model_preview_image(model_path, post["previewFile"])
    if "description" in post:
        utils.save_model_description(model_path, post["description"])

    if rename_args is not None:
        # get new path
        new_model_path = utils.get_full_path(*rename_args)
        utils.rename_model(model_path, new_model_path)
def remove_model(model_path: str):
    """
    Delete a model file and the preview/description files that belong to it.

    The model file itself is removed first. Then every name returned by
    ``utils.get_model_all_images`` and ``utils.get_model_all_descriptions``
    for ``model_path`` is joined onto the model's directory and deleted.
    """
    parent_dir = os.path.dirname(model_path)
    os.remove(model_path)

    def _unlink_all(names):
        # Delete each companion file located next to the model.
        for name in names:
            os.remove(os.path.join(parent_dir, name))

    _unlink_all(utils.get_model_all_images(model_path))
    _unlink_all(utils.get_model_all_descriptions(model_path))
async def create_model_download_task(post):
    """
    Create a model download task from the submitted request data.

    ``post`` is copied into a plain ``dict`` and forwarded to
    ``download.create_model_download_task``. Its awaited result is returned
    unchanged.
    """
    task_payload = dict(post)
    created = await download.create_model_download_task(task_payload)
    return created