diff --git a/.gitignore b/.gitignore index 9dd999f..2585b6a 100644 --- a/.gitignore +++ b/.gitignore @@ -194,3 +194,6 @@ node_modules/ # dist web/ + +# config +config/ diff --git a/.prettierrc b/.prettierrc index 8da9e7e..0113523 100644 --- a/.prettierrc +++ b/.prettierrc @@ -7,6 +7,7 @@ "endOfLine": "lf", "semi": false, "plugins": [ - "prettier-plugin-organize-imports" + "prettier-plugin-organize-imports", + "prettier-plugin-tailwindcss" ] } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 66ffc70..cb2e947 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,23 +1,47 @@ { "cSpell.words": [ - "apng", - "Civitai", - "ckpt", + "tailwindcss", + "vnode", + "unref", + "civitai", + "huggingface", "comfyui", - "FYUIKMNVB", - "gguf", + "ckpt", "gligen", - "jfif", - "locon", "loras", - "noimage", - "onnx", - "rfilename", + "safetensors", "unet", - "upscaler" + "controlnet", + "hypernetwork", + "hypernetworks", + "photomaker", + "upscaler", + "comfyorg", + "fullname", + "primevue", + "maximizable", + "inputgroup", + "inputgroupaddon", + "iconfield", + "inputicon", + "inputtext", + "overlaybadge", + "usetoast", + "toastservice", + "useconfirm", + "confirmationservice", + "confirmdialog", + "popupmenu", + "inplace", + "contentcontainer", + "itemlist", ], "editor.defaultFormatter": "esbenp.prettier-vscode", "files.associations": { "*.css": "tailwindcss" - } + }, + "editor.quickSuggestions": { + "strings": "on" + }, + "css.lint.unknownAtRules": "ignore" } \ No newline at end of file diff --git a/README.md b/README.md index 41626f7..f9f50a1 100644 --- a/README.md +++ b/README.md @@ -4,64 +4,60 @@ Download, browse and delete models in ComfyUI. Designed to support desktop, mobile and multi-screen devices. -Model Manager Demo Screenshot +# Usage -Model Manager Demo Screenshot +```bash +cd /path/to/ComfyUI/custom_nodes +git clone https://github.com/hayden-fr/ComfyUI-Model-Manager.git +cd /path/to/ComfyUI/custom_nodes/ComfyUI-Model-Manager +npm install +npm run build +``` ## Features -### Node Graph +## Freely adjust size and position -Model Manager Demo Screenshot + + +### Support Node Graph + + - Drag a model thumbnail onto the graph to add a new node. - Drag a model thumbnail onto an existing node to set the input field. - If there are multiple valid possible fields, then the drag must be exact. - Drag an embedding thumbnail onto a text area, or highlight any number of nodes, to append it onto the end of the text. - Drag the preview image in a model's info view onto the graph to load the embedded workflow (if it exists). - -Model Manager Demo Screenshot - - Press the "copy" button to copy a model to ComfyUI's clipboard or copy the embedding to the system clipboard. (Copying the embedding to the system clipboard requires a secure http connection.) - Press the "add" button to add the model to the ComfyUI graph or append the embedding to one or more selected nodes. - Press the "load workflow" button to try and load a workflow embedded in a model's preview image. ### Download Tab -Model Manager Demo Screenshot + - View multiple models associated with a url. - Select a save directory and input a filename. - Optionally set a model's preview image. -- Optionally edit and save descriptions as a .txt note. (Default behavior can be set in the settings tab.) -- Add Civitai and HuggingFace API tokens in `server_settings.yaml`. +- Optionally edit and save descriptions as a .md note. +- Add Civitai and HuggingFace API tokens in ComfyUI's settings. + + ### Models Tab -Model Manager Demo Screenshot +Model Manager Demo Screenshot - Search in real-time for models using the search bar. -- Use advance keyword search by typing `"multiple words in quotes"` or a minus sign before to `-exclude` a word or phrase. -- Add `/` at the start of a search to view a dropdown list of subdirectories (for example, `/0/1.5/styles/clothing`). - - Any directory paths in ComfyUI's `extra_model_paths.yaml` or directories added in `ComfyUI/models/` will automatically be detected. -- Sort models by "Date Created", "Date Modified", "Name" and "File Size". +- Sort models by "Name", "File Size", "Date Created" and "Date Modified". ### Model Info View -Model Manager Demo Screenshot +Model Manager Demo Screenshot - View file info and metadata. - Rename, move or **permanently** remove a model and all of it's related files. -- Read, edit and save notes. (Saved as a `.txt` file beside the model). - - `Ctrl+s` or `⌘+S` to save a note when the textarea is in focus. - - Autosave can be enabled in settings. (Note: Once the model info view is closed, the undo history is lost.) +- Read, edit and save notes. (Saved as a `.md` file beside the model). - Change or remove a model's preview image. - View training tags and use the random tag generator to generate prompt ideas. (Inspired by the one in A1111.) - -### Settings Tab - -Model Manager Demo Screenshot - -- Settings are saved to `ui_settings.yaml`. -- Most settings should update immediately, but a few may require a page reload to take effect. -- Press the "Fix Extensions" button to correct all image file extensions in the model directories. (Note: This may take a minute or so to complete.) diff --git a/__init__.py b/__init__.py index 6aadb9b..1525a25 100644 --- a/__init__.py +++ b/__init__.py @@ -1,1228 +1,183 @@ import os -import io -import pathlib -import shutil -from datetime import datetime -import sys -import copy -import importlib -import re -import base64 - -from aiohttp import web -import server -import urllib.parse -import urllib.request -import struct -import json -import requests -requests.packages.urllib3.disable_warnings() - -import comfy.utils import folder_paths - -comfyui_model_uri = folder_paths.models_dir - -extension_uri = os.path.dirname(__file__) - -config_loader_path = os.path.join(extension_uri, 'config_loader.py') -config_loader_spec = importlib.util.spec_from_file_location('config_loader', config_loader_path) -config_loader = importlib.util.module_from_spec(config_loader_spec) -config_loader_spec.loader.exec_module(config_loader) - -no_preview_image = os.path.join(extension_uri, "no-preview.png") -ui_settings_uri = os.path.join(extension_uri, "ui_settings.yaml") -server_settings_uri = os.path.join(extension_uri, "server_settings.yaml") - -fallback_model_extensions = set([".bin", ".ckpt", ".gguf", ".onnx", ".pt", ".pth", ".safetensors"]) # TODO: magic values -jpeg_format_names = ["JPG", "JPEG", "JFIF"] -image_extensions = ( - ".png", # order matters - ".webp", - ".jpeg", - ".jpg", - ".jfif", - ".gif", - ".apng", -) -stable_diffusion_webui_civitai_helper_image_extensions = ( - ".preview.png", # order matters - ".preview.webp", - ".preview.jpeg", - ".preview.jpg", - ".preview.jfif", - ".preview.gif", - ".preview.apng", -) -preview_extensions = ( # TODO: JavaScript does not know about this (x2 states) - image_extensions + # order matters - stable_diffusion_webui_civitai_helper_image_extensions -) -model_info_extension = ".txt" -#video_extensions = (".avi", ".mp4", ".webm") # TODO: Requires ffmpeg or cv2. Cache preview frame? - -def split_valid_ext(s, *arg_exts): - sl = s.lower() - for exts in arg_exts: - for ext in exts: - if sl.endswith(ext.lower()): - return (s[:-len(ext)], ext) - return (s, "") - -_folder_names_and_paths = None # dict[str, tuple[list[str], list[str]]] -def folder_paths_folder_names_and_paths(refresh = False): - global _folder_names_and_paths - if refresh or _folder_names_and_paths is None: - _folder_names_and_paths = {} - for item_name in os.listdir(comfyui_model_uri): - item_path = os.path.join(comfyui_model_uri, item_name) - if not os.path.isdir(item_path): - continue - if item_name == "configs": - continue - if item_name in folder_paths.folder_names_and_paths: - dir_paths, extensions = copy.deepcopy(folder_paths.folder_names_and_paths[item_name]) - else: - dir_paths = [item_path] - extensions = copy.deepcopy(fallback_model_extensions) - _folder_names_and_paths[item_name] = (dir_paths, extensions) - return _folder_names_and_paths - -def folder_paths_get_folder_paths(folder_name, refresh = False): # API function crashes querying unknown model folder - paths = folder_paths_folder_names_and_paths(refresh) - if folder_name in paths: - return paths[folder_name][0] - - maybe_path = os.path.join(comfyui_model_uri, folder_name) - if os.path.exists(maybe_path): - return [maybe_path] - return [] - -def folder_paths_get_supported_pt_extensions(folder_name, refresh = False): # Missing API function - paths = folder_paths_folder_names_and_paths(refresh) - if folder_name in paths: - return paths[folder_name][1] - model_extensions = copy.deepcopy(fallback_model_extensions) - return model_extensions +from .py import config +from .py import utils -def search_path_to_system_path(model_path): - sep = os.path.sep - model_path = os.path.normpath(model_path.replace("/", sep)) - model_path = model_path.lstrip(sep) +# Init config settings +config.extension_uri = os.path.dirname(__file__) +utils.resolve_model_base_paths() - isep1 = model_path.find(sep, 0) - if isep1 == -1 or isep1 == len(model_path): - return (None, None) - isep2 = model_path.find(sep, isep1 + 1) - if isep2 == -1 or isep2 - isep1 == 1: - isep2 = len(model_path) +import logging +from aiohttp import web +import traceback +from .py import services - model_path_type = model_path[0:isep1] - paths = folder_paths_get_folder_paths(model_path_type) - if len(paths) == 0: - return (None, None) - model_path_index = model_path[isep1 + 1:isep2] +routes = config.routes + + +@routes.get("/model-manager/ws") +async def socket_handler(request): + """ + Handle websocket connection. + """ + ws = await services.connect_websocket(request) + return ws + + +@routes.get("/model-manager/base-folders") +async def get_model_paths(request): + """ + Returns the base folders for models. + """ + model_base_paths = config.model_base_paths + return web.json_response({"success": True, "data": model_base_paths}) + + +@routes.post("/model-manager/model") +async def create_model(request): + """ + Create a new model. + + request body: x-www-form-urlencoded + - type: model type. + - pathIndex: index of the model folders. + - fullname: filename that relative to the model folder. + - previewFile: preview file. + - description: description. + - downloadPlatform: download platform. + - downloadUrl: download url. + - hash: a JSON string containing the hash value of the downloaded model. + """ + post = await request.post() try: - model_path_index = int(model_path_index) - except: - return (None, None) - if model_path_index < 0 or model_path_index >= len(paths): - return (None, None) - - system_path = os.path.normpath( - paths[model_path_index] + - sep + - model_path[isep2:] - ) - - return (system_path, model_path_type) - - -def get_safetensor_header(path): - try: - header_bytes = comfy.utils.safetensors_header(path) - header_json = json.loads(header_bytes) - return header_json if header_json is not None else {} - except: - return {} - - -def end_swap_and_pop(x, i): - x[i], x[-1] = x[-1], x[i] - return x.pop(-1) - - -def model_type_to_dir_name(model_type): - if model_type == "checkpoint": return "checkpoints" - #elif model_type == "clip": return "clip" - #elif model_type == "clip_vision": return "clip_vision" - #elif model_type == "controlnet": return "controlnet" - elif model_type == "diffuser": return "diffusers" - elif model_type == "embedding": return "embeddings" - #elif model_type== "gligen": return "gligen" - elif model_type == "hypernetwork": return "hypernetworks" - elif model_type == "lora": return "loras" - #elif model_type == "style_models": return "style_models" - #elif model_type == "unet": return "unet" - elif model_type == "upscale_model": return "upscale_models" - #elif model_type == "vae": return "vae" - #elif model_type == "vae_approx": return "vae_approx" - else: return model_type - - -def ui_rules(): - Rule = config_loader.Rule - return [ - Rule("model-search-always-append", "", str), - Rule("model-default-browser-model-type", "checkpoints", str), - Rule("model-real-time-search", True, bool), - Rule("model-persistent-search", True, bool), - - Rule("model-preview-thumbnail-type", "AUTO", str), - Rule("model-preview-fallback-search-safetensors-thumbnail", False, bool), - Rule("model-show-label-extensions", False, bool), - Rule("model-show-add-button", True, bool), - Rule("model-show-copy-button", True, bool), - Rule("model-show-load-workflow-button", True, bool), - Rule("model-info-button-on-left", False, bool), - - Rule("model-add-embedding-extension", False, bool), - Rule("model-add-drag-strict-on-field", False, bool), - Rule("model-add-offset", 25, int), - - Rule("model-info-autosave-notes", False, bool), - - Rule("download-save-description-as-text-file", True, bool), - - Rule("sidebar-control-always-compact", False, bool), - Rule("sidebar-default-width", 0.5, float, 0.0, 1.0), - Rule("sidebar-default-height", 0.5, float, 0.0, 1.0), - Rule("text-input-always-hide-search-button", False, bool), - Rule("text-input-always-hide-clear-button", False, bool), - - Rule("tag-generator-sampler-method", "Frequency", str), - Rule("tag-generator-count", 10, int), - Rule("tag-generator-threshold", 2, int), - ] - - -def server_rules(): - Rule = config_loader.Rule - return [ - #Rule("model_extension_download_whitelist", [".safetensors"], list), - Rule("civitai_api_key", "", str), - Rule("huggingface_api_key", "", str), - ] -server_settings = config_loader.yaml_load(server_settings_uri, server_rules()) -config_loader.yaml_save(server_settings_uri, server_rules(), server_settings) - - -def get_def_headers(url=""): - def_headers = { - "User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148", - } - - if url.startswith("https://civitai.com/"): - api_key = server_settings["civitai_api_key"] - if (api_key != ""): - def_headers["Authorization"] = f"Bearer {api_key}" - url += "&" if "?" in url else "?" # not the most robust solution - url += f"token={api_key}" # TODO: Authorization didn't work in the header - elif url.startswith("https://huggingface.co/"): - api_key = server_settings["huggingface_api_key"] - if api_key != "": - def_headers["Authorization"] = f"Bearer {api_key}" - - return def_headers - - -@server.PromptServer.instance.routes.get("/model-manager/timestamp") -async def get_timestamp(request): - return web.json_response({ "timestamp": datetime.now().timestamp() }) - - -@server.PromptServer.instance.routes.get("/model-manager/settings/load") -async def load_ui_settings(request): - rules = ui_rules() - settings = config_loader.yaml_load(ui_settings_uri, rules) - return web.json_response({ "settings": settings }) - - -@server.PromptServer.instance.routes.post("/model-manager/settings/save") -async def save_ui_settings(request): - body = await request.json() - settings = body.get("settings") - rules = ui_rules() - validated_settings = config_loader.validated(rules, settings) - success = config_loader.yaml_save(ui_settings_uri, rules, validated_settings) - print("Saved file: " + ui_settings_uri) - return web.json_response({ - "success": success, - "settings": validated_settings if success else "", - }) - - -from PIL import Image, TiffImagePlugin -from PIL.PngImagePlugin import PngInfo -def PIL_cast_serializable(v): - # source: https://github.com/python-pillow/Pillow/issues/6199#issuecomment-1214854558 - if isinstance(v, TiffImagePlugin.IFDRational): - return float(v) - elif isinstance(v, tuple): - return tuple(PIL_cast_serializable(t) for t in v) - elif isinstance(v, bytes): - return v.decode(errors="replace") - elif isinstance(v, dict): - for kk, vv in v.items(): - v[kk] = PIL_cast_serializable(vv) - return v - else: - return v - - -def get_safetensors_image_bytes(path): - if not os.path.isfile(path): - raise RuntimeError("Path was invalid!") - header = get_safetensor_header(path) - metadata = header.get("__metadata__", None) - if metadata is None: - return None - thumbnail = metadata.get("modelspec.thumbnail", None) - if thumbnail is None: - return None - image_data = thumbnail.split(',')[1] - return base64.b64decode(image_data) - - -def get_image_info(image): - metadata = None - if len(image.info) > 0: - metadata = PngInfo() - for (key, value) in image.info.items(): - value_str = str(PIL_cast_serializable(value)) # not sure if this is correct (sometimes includes exif) - metadata.add_text(key, value_str) - return metadata - - -def image_format_is_equal(f1, f2): - if not isinstance(f1, str) or not isinstance(f2, str): - return False - if f1[0] == ".": f1 = f1[1:] - if f2[0] == ".": f2 = f2[1:] - f1 = f1.upper() - f2 = f2.upper() - return f1 == f2 or (f1 in jpeg_format_names and f2 in jpeg_format_names) - - -def get_auto_thumbnail_format(original_format): - if original_format in ["JPEG", "WEBP", "JPG"]: # JFIF? - return original_format - return "JPEG" # default fallback - - -@server.PromptServer.instance.routes.get("/model-manager/preview/get") -async def get_model_preview(request): - uri = request.query.get("uri") - quality = 75 - response_image_format = request.query.get("image-format", None) - if isinstance(response_image_format, str): - response_image_format = response_image_format.upper() - - image_path = no_preview_image - file_name = os.path.split(no_preview_image)[1] - if uri != "no-preview": - sep = os.path.sep - uri = uri.replace("/" if sep == "\\" else "/", sep) - path, _ = search_path_to_system_path(uri) - head, extension = split_valid_ext(path, preview_extensions) - if os.path.exists(path): - image_path = path - file_name = os.path.split(head)[1] + extension - elif os.path.exists(head) and head.endswith(".safetensors"): - image_path = head - file_name = os.path.splitext(os.path.split(head)[1])[0] + extension - - w = request.query.get("width") - h = request.query.get("height") - try: - w = int(w) - if w < 1: - w = None - except: - w = None - try: - h = int(h) - if w < 1: - h = None - except: - h = None - - image_data = None - if w is None and h is None: # full size - if image_path.endswith(".safetensors"): - image_data = get_safetensors_image_bytes(image_path) - else: - with open(image_path, "rb") as image: - image_data = image.read() - fp = io.BytesIO(image_data) - with Image.open(fp) as image: - image_format = image.format - if response_image_format is None: - response_image_format = image_format - elif response_image_format == "AUTO": - response_image_format = get_auto_thumbnail_format(image_format) - - if not image_format_is_equal(response_image_format, image_format): - exif = image.getexif() - metadata = get_image_info(image) - if response_image_format in jpeg_format_names: - image = image.convert('RGB') - image_bytes = io.BytesIO() - image.save(image_bytes, format=response_image_format, exif=exif, pnginfo=metadata, quality=quality) - image_data = image_bytes.getvalue() - else: - if image_path.endswith(".safetensors"): - image_data = get_safetensors_image_bytes(image_path) - fp = io.BytesIO(image_data) - else: - fp = image_path - - with Image.open(fp) as image: - image_format = image.format - if response_image_format is None: - response_image_format = image_format - elif response_image_format == "AUTO": - response_image_format = get_auto_thumbnail_format(image_format) - - w0, h0 = image.size - if w is None: - w = (h * w0) // h0 - elif h is None: - h = (w * h0) // w0 - - exif = image.getexif() - metadata = get_image_info(image) - - ratio_original = w0 / h0 - ratio_thumbnail = w / h - if abs(ratio_original - ratio_thumbnail) < 0.01: - crop_box = (0, 0, w0, h0) - elif ratio_original > ratio_thumbnail: - crop_width_fp = h0 * w / h - x0 = int((w0 - crop_width_fp) / 2) - crop_box = (x0, 0, x0 + int(crop_width_fp), h0) - else: - crop_height_fp = w0 * h / w - y0 = int((h0 - crop_height_fp) / 2) - crop_box = (0, y0, w0, y0 + int(crop_height_fp)) - image = image.crop(crop_box) - - if w < w0 and h < h0: - resampling_method = Image.Resampling.BOX - else: - resampling_method = Image.Resampling.BICUBIC - image.thumbnail((w, h), resample=resampling_method) - - if not image_format_is_equal(image_format, response_image_format) and response_image_format in jpeg_format_names: - image = image.convert('RGB') - image_bytes = io.BytesIO() - image.save(image_bytes, format=response_image_format, exif=exif, pnginfo=metadata, quality=quality) - image_data = image_bytes.getvalue() - - response_file_name = os.path.splitext(file_name)[0] + '.' + response_image_format.lower() - return web.Response( - headers={ - "Content-Disposition": f"inline; filename={response_file_name}", - }, - body=image_data, - content_type="image/" + response_image_format.lower(), - ) - - -@server.PromptServer.instance.routes.get("/model-manager/image/extensions") -async def get_image_extensions(request): - return web.json_response(image_extensions) - - -def download_model_preview(formdata): - path = formdata.get("path", None) - if type(path) is not str: - raise ValueError("Invalid path!") - path, model_type = search_path_to_system_path(path) - model_type_extensions = folder_paths_get_supported_pt_extensions(model_type) - path_without_extension, _ = split_valid_ext(path, model_type_extensions) - - overwrite = formdata.get("overwrite", "true").lower() - overwrite = True if overwrite == "true" else False - - image = formdata.get("image", None) - if type(image) is str: - civitai_image_url = "https://civitai.com/images/" - if image.startswith(civitai_image_url): - image_id = re.search(r"^\d+", image[len(civitai_image_url):]).group(0) - image_id = str(int(image_id)) - image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}" - def_headers = get_def_headers(image_info_url) - response = requests.get( - url=image_info_url, - stream=False, - verify=False, - headers=def_headers, - proxies=None, - allow_redirects=False, - ) - if response.ok: - content_type = response.headers.get("Content-Type") - info = response.json() - items = info["items"] - if len(items) == 0: - raise RuntimeError("Civitai /api/v1/images returned 0 items!") - image = items[0]["url"] - else: - raise RuntimeError("Bad response from api/v1/images!") - _, image_extension = split_valid_ext(image, image_extensions) - if image_extension == "": - raise ValueError("Invalid image type!") - image_path = path_without_extension + image_extension - download_file(image, image_path, overwrite) - else: - content_type = image.content_type - if not content_type.startswith("image/"): - raise RuntimeError("Invalid content type!") - image_extension = "." + content_type[len("image/"):] - if image_extension not in image_extensions: - raise RuntimeError("Invalid extension!") - - image_path = path_without_extension + image_extension - if not overwrite and os.path.isfile(image_path): - raise RuntimeError("Image already exists!") - file: io.IOBase = image.file - image_data = file.read() - with open(image_path, "wb") as f: - f.write(image_data) - print("Saved file: " + image_path) - - if overwrite: - delete_same_name_files(path_without_extension, preview_extensions, image_extension) - - # detect (and try to fix) wrong file extension - image_format = None - with Image.open(image_path) as image: - image_format = image.format - image_dir_and_name, image_ext = os.path.splitext(image_path) - if not image_format_is_equal(image_format, image_ext): - corrected_image_path = image_dir_and_name + "." + image_format.lower() - if os.path.exists(corrected_image_path) and not overwrite: - print("WARNING: '" + image_path + "' has wrong extension!") - else: - os.rename(image_path, corrected_image_path) - print("Saved file: " + corrected_image_path) - image_path = corrected_image_path - return image_path # return in-case need corrected path - - -@server.PromptServer.instance.routes.post("/model-manager/preview/set") -async def set_model_preview(request): - formdata = await request.post() - try: - download_model_preview(formdata) - return web.json_response({ "success": True }) - except ValueError as e: - print(e, file=sys.stderr, flush=True) - return web.json_response({ - "success": False, - "alert": "Failed to set preview!\n\n" + str(e), - }) - - -@server.PromptServer.instance.routes.post("/model-manager/preview/delete") -async def delete_model_preview(request): - result = { "success": False } - - model_path = request.query.get("path", None) - if model_path is None: - result["alert"] = "Missing model path!" - return web.json_response(result) - model_path = urllib.parse.unquote(model_path) - - model_path, model_type = search_path_to_system_path(model_path) - model_extensions = folder_paths_get_supported_pt_extensions(model_type) - path_and_name, _ = split_valid_ext(model_path, model_extensions) - delete_same_name_files(path_and_name, preview_extensions) - - result["success"] = True - return web.json_response(result) - - -def correct_image_extensions(root_dir): - detected_image_count = 0 - corrected_image_count = 0 - for root, dirs, files in os.walk(root_dir): - for file_name in files: - file_path = root + os.path.sep + file_name - image_format = None - try: - with Image.open(file_path) as image: - image_format = image.format - except: - continue - image_path = file_path - image_dir_and_name, image_ext = os.path.splitext(image_path) - if not image_format_is_equal(image_format, image_ext): - detected_image_count += 1 - corrected_image_path = image_dir_and_name + "." + image_format.lower() - if os.path.exists(corrected_image_path): - print("WARNING: '" + image_path + "' has wrong extension!") - else: - try: - os.rename(image_path, corrected_image_path) - except: - print("WARNING: Unable to rename '" + image_path + "'!") - continue - ext0 = os.path.splitext(image_path)[1] - ext1 = os.path.splitext(corrected_image_path)[1] - print(f"({ext0} -> {ext1}): {corrected_image_path}") - corrected_image_count += 1 - return (detected_image_count, corrected_image_count) - - -@server.PromptServer.instance.routes.get("/model-manager/preview/correct-extensions") -async def correct_preview_extensions(request): - result = { "success": False } - - detected = 0 - corrected = 0 - - model_types = os.listdir(comfyui_model_uri) - model_types.remove("configs") - model_types.sort() - - for model_type in model_types: - for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)): - if not os.path.exists(model_base_path): # TODO: Bug in main code? ("ComfyUI\output\checkpoints", "ComfyUI\output\clip", "ComfyUI\models\t2i_adapter", "ComfyUI\output\vae") - continue - d, c = correct_image_extensions(model_base_path) - detected += d - corrected += c - - result["success"] = True - result["detected"] = detected - result["corrected"] = corrected - return web.json_response(result) - - -@server.PromptServer.instance.routes.get("/model-manager/models/list") -async def get_model_list(request): - use_safetensor_thumbnail = ( - config_loader.yaml_load(ui_settings_uri, ui_rules()) - .get("model-preview-fallback-search-safetensors-thumbnail", False) - ) - - model_types = os.listdir(comfyui_model_uri) - model_types.remove("configs") - model_types.sort() - - models = {} - for model_type in model_types: - model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type)) - file_infos = [] - for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)): - if not os.path.exists(model_base_path): # TODO: Bug in main code? ("ComfyUI\output\checkpoints", "ComfyUI\output\clip", "ComfyUI\models\t2i_adapter", "ComfyUI\output\vae") - continue - for cwd, subdirs, files in os.walk(model_base_path): - dir_models = [] - dir_images = [] - - for file in files: - if file.lower().endswith(model_extensions): - dir_models.append(file) - elif file.lower().endswith(preview_extensions): - dir_images.append(file) - - for model in dir_models: - model_name, model_ext = split_valid_ext(model, model_extensions) - image = None - image_modified = None - for ext in preview_extensions: # order matters - for iImage in range(len(dir_images)-1, -1, -1): - image_name = dir_images[iImage] - if not image_name.lower().endswith(ext.lower()): - continue - image_name = image_name[:-len(ext)] - if model_name == image_name: - image = end_swap_and_pop(dir_images, iImage) - img_abs_path = os.path.join(cwd, image) - image_modified = pathlib.Path(img_abs_path).stat().st_mtime_ns - break - if image is not None: - break - abs_path = os.path.join(cwd, model) - stats = pathlib.Path(abs_path).stat() - sizeBytes = stats.st_size - model_modified = stats.st_mtime_ns - model_created = stats.st_ctime_ns - if use_safetensor_thumbnail and image is None and model_ext == ".safetensors": - # try to fallback on safetensor embedded thumbnail - header = get_safetensor_header(abs_path) - metadata = header.get("__metadata__", None) - if metadata is not None: - thumbnail = metadata.get("modelspec.thumbnail", None) - if thumbnail is not None: - i0 = thumbnail.find("/") + 1 - i1 = thumbnail.find(";") - image_ext = "." + thumbnail[i0:i1] - if image_ext in image_extensions: - image = model + image_ext - image_modified = model_modified - rel_path = "" if cwd == model_base_path else os.path.relpath(cwd, model_base_path) - info = ( - model, - image, - base_path_index, - rel_path, - model_modified, - model_created, - image_modified, - sizeBytes, - ) - file_infos.append(info) - #file_infos.sort(key=lambda tup: tup[4], reverse=True) # TODO: remove sort; sorted on client - - model_items = [] - for model, image, base_path_index, rel_path, model_modified, model_created, image_modified, sizeBytes in file_infos: - item = { - "name": model, - "path": "/" + os.path.join(model_type, str(base_path_index), rel_path, model).replace(os.path.sep, "/"), # relative logical path - #"systemPath": os.path.join(rel_path, model), # relative system path (less information than "search path") - "dateModified": model_modified, - "dateCreated": model_created, - #"dateLastUsed": "", # TODO: track server-side, send increment client-side - #"countUsed": 0, # TODO: track server-side, send increment client-side - "sizeBytes": sizeBytes, - } - if image is not None: - raw_post = os.path.join(model_type, str(base_path_index), rel_path, image) - item["preview"] = { - "path": urllib.parse.quote_plus(raw_post), - "dateModified": urllib.parse.quote_plus(str(image_modified)), - } - model_items.append(item) - - models[model_type] = model_items - - return web.json_response(models) - - -def linear_directory_hierarchy(refresh = False): - model_paths = folder_paths_folder_names_and_paths(refresh) - dir_list = [] - dir_list.append({ "name": "", "childIndex": 1, "childCount": len(model_paths) }) - for model_dir_name, (model_dirs, _) in model_paths.items(): - dir_list.append({ "name": model_dir_name, "childIndex": None, "childCount": len(model_dirs) }) - for model_dir_index, (_, (model_dirs, extension_whitelist)) in enumerate(model_paths.items()): - model_dir_child_index = len(dir_list) - dir_list[model_dir_index + 1]["childIndex"] = model_dir_child_index - for dir_path_index, dir_path in enumerate(model_dirs): - dir_list.append({ "name": str(dir_path_index), "childIndex": None, "childCount": None }) - for dir_path_index, dir_path in enumerate(model_dirs): - if not os.path.exists(dir_path) or os.path.isfile(dir_path): - continue - - #dir_list.append({ "name": str(dir_path_index), "childIndex": None, "childCount": 0 }) - dir_stack = [(dir_path, model_dir_child_index + dir_path_index)] - while len(dir_stack) > 0: # DEPTH-FIRST - dir_path, dir_index = dir_stack.pop() - - dir_items = os.listdir(dir_path) - dir_items = sorted(dir_items, key=str.casefold) - - dir_child_count = 0 - - # TODO: sort content of directory: alphabetically - # TODO: sort content of directory: files first - - subdirs = [] - for item_name in dir_items: # BREADTH-FIRST - item_path = os.path.join(dir_path, item_name) - if os.path.isdir(item_path): - # dir - subdir_index = len(dir_list) # this must be done BEFORE `dir_list.append` - subdirs.append((item_path, subdir_index)) - dir_list.append({ "name": item_name, "childIndex": None, "childCount": 0 }) - dir_child_count += 1 - else: - # file - if extension_whitelist is None or split_valid_ext(item_name, extension_whitelist)[1] != "": - dir_list.append({ "name": item_name }) - dir_child_count += 1 - if dir_child_count > 0: - dir_list[dir_index]["childIndex"] = len(dir_list) - dir_child_count - dir_list[dir_index]["childCount"] = dir_child_count - subdirs.reverse() - for dir_path, subdir_index in subdirs: - dir_stack.append((dir_path, subdir_index)) - return dir_list - - -@server.PromptServer.instance.routes.get("/model-manager/models/directory-list") -async def get_directory_list(request): - #body = await request.json() - dir_list = linear_directory_hierarchy(True) - #json.dump(dir_list, sys.stdout, indent=4) - return web.json_response(dir_list) - - -def download_file(url, filename, overwrite): - if not overwrite and os.path.isfile(filename): - raise ValueError("File already exists!") - - filename_temp = filename + ".download" - - def_headers = get_def_headers(url) - rh = requests.get( - url=url, - stream=True, - verify=False, - headers=def_headers, - proxies=None, - allow_redirects=False, - ) - if not rh.ok: - raise ValueError( - "Unable to download! Request header status code: " + - str(rh.status_code) - ) - - downloaded_size = 0 - if rh.status_code == 200 and os.path.exists(filename_temp): - downloaded_size = os.path.getsize(filename_temp) - - headers = {"Range": "bytes=%d-" % downloaded_size} - headers["User-Agent"] = def_headers["User-Agent"] - headers["Authorization"] = def_headers.get("Authorization", None) - - r = requests.get( - url=url, - stream=True, - verify=False, - headers=headers, - proxies=None, - allow_redirects=False, - ) - if rh.status_code == 307 and r.status_code == 307: - # Civitai redirect - redirect_url = r.content.decode("utf-8") - if not redirect_url.startswith("http"): - # Civitai requires login (NSFW or user-required) - # TODO: inform user WHY download failed - raise ValueError("Unable to download from Civitai! Redirect url: " + str(redirect_url)) - download_file(redirect_url, filename, overwrite) - return - if rh.status_code == 302 and r.status_code == 302: - # HuggingFace redirect - redirect_url = r.content.decode("utf-8") - redirect_url_index = redirect_url.find("http") - if redirect_url_index == -1: - raise ValueError("Unable to download from HuggingFace! Redirect url: " + str(redirect_url)) - download_file(redirect_url[redirect_url_index:], filename, overwrite) - return - elif rh.status_code == 200 and r.status_code == 206: - # Civitai download link - pass - - total_size = int(rh.headers.get("Content-Length", 0)) # TODO: pass in total size earlier - - print("Downloading file: " + url) - if total_size != 0: - print("Download file size: " + str(total_size)) - - mode = "wb" if overwrite else "ab" - with open(filename_temp, mode) as f: - for chunk in r.iter_content(chunk_size=1024): - if chunk is not None: - downloaded_size += len(chunk) - f.write(chunk) - f.flush() - - if total_size != 0: - fraction = 1 if downloaded_size == total_size else downloaded_size / total_size - progress = int(50 * fraction) - sys.stdout.reconfigure(encoding="utf-8") - sys.stdout.write( - "\r[%s%s] %d%%" - % ( - "-" * progress, - " " * (50 - progress), - 100 * fraction, - ) - ) - sys.stdout.flush() - print() - - if overwrite and os.path.isfile(filename): - os.remove(filename) - os.rename(filename_temp, filename) - print("Saved file: " + filename) - - -def bytes_to_size(total_bytes): - units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"] - b = total_bytes - i = 0 - while True: - b = b >> 10 - if (b == 0): break - i = i + 1 - if i >= len(units) or i == 0: - return str(total_bytes) + " " + units[0] - return "{:.2f}".format(total_bytes / (1 << (i * 10))) + " " + units[i] - - -@server.PromptServer.instance.routes.get("/model-manager/model/info") -async def get_model_info(request): - result = { "success": False } - - model_path = request.query.get("path", None) - if model_path is None: - result["alert"] = "Missing model path!" - return web.json_response(result) - model_path = urllib.parse.unquote(model_path) - - abs_path, model_type = search_path_to_system_path(model_path) - if abs_path is None: - result["alert"] = "Invalid model path!" - return web.json_response(result) - - info = {} - comfyui_directory, name = os.path.split(model_path) - info["File Name"] = name - info["File Directory"] = comfyui_directory - info["File Size"] = bytes_to_size(os.path.getsize(abs_path)) - stats = pathlib.Path(abs_path).stat() - date_format = "%Y-%m-%d %H:%M:%S" - date_modified = datetime.fromtimestamp(stats.st_mtime).strftime(date_format) - #info["Date Modified"] = date_modified - #info["Date Created"] = datetime.fromtimestamp(stats.st_ctime).strftime(date_format) - - model_extensions = folder_paths_get_supported_pt_extensions(model_type) - abs_name , _ = split_valid_ext(abs_path, model_extensions) - - for extension in preview_extensions: - maybe_preview = abs_name + extension - if os.path.isfile(maybe_preview): - preview_path, _ = split_valid_ext(model_path, model_extensions) - preview_modified = pathlib.Path(maybe_preview).stat().st_mtime_ns - info["Preview"] = { - "path": urllib.parse.quote_plus(preview_path + extension), - "dateModified": urllib.parse.quote_plus(str(preview_modified)), - } - break - - header = get_safetensor_header(abs_path) - metadata = header.get("__metadata__", None) - - if metadata is not None and info.get("Preview", None) is None: - thumbnail = metadata.get("modelspec.thumbnail") - if thumbnail is not None: - i0 = thumbnail.find("/") + 1 - i1 = thumbnail.find(";", i0) - thumbnail_extension = "." + thumbnail[i0:i1] - if thumbnail_extension in image_extensions: - info["Preview"] = { - "path": request.query["path"] + thumbnail_extension, - "dateModified": date_modified, - } - - if metadata is not None: - info["Base Training Model"] = metadata.get("ss_sd_model_name", "") - info["Base Model Version"] = metadata.get("ss_base_model_version", "") - info["Network Dimension"] = metadata.get("ss_network_dim", "") - info["Network Alpha"] = metadata.get("ss_network_alpha", "") - - if metadata is not None: - training_comment = metadata.get("ss_training_comment", "") - info["Description"] = ( - metadata.get("modelspec.description", "") + - "\n\n" + - metadata.get("modelspec.usage_hint", "") + - "\n\n" + - training_comment if training_comment != "None" else "" - ).strip() - - info_text_file = abs_name + model_info_extension - notes = "" - if os.path.isfile(info_text_file): - with open(info_text_file, 'r', encoding="utf-8") as f: - notes = f.read() - - if metadata is not None: - img_buckets = metadata.get("ss_bucket_info", None) - datasets = metadata.get("ss_datasets", None) - - if type(img_buckets) is str: - img_buckets = json.loads(img_buckets) - elif type(datasets) is str: - datasets = json.loads(datasets) - if isinstance(datasets, list): - datasets = datasets[0] - img_buckets = datasets.get("bucket_info", None) - resolutions = {} - if img_buckets is not None: - buckets = img_buckets.get("buckets", {}) - for resolution in buckets.values(): - dim = resolution["resolution"] - x, y = dim[0], dim[1] - count = resolution["count"] - resolutions[str(x) + "x" + str(y)] = count - resolutions = list(resolutions.items()) - resolutions.sort(key=lambda x: x[1], reverse=True) - info["Bucket Resolutions"] = resolutions - - tags = None - if metadata is not None: - dir_tags = metadata.get("ss_tag_frequency", "{}") - if type(dir_tags) is str: - dir_tags = json.loads(dir_tags) - tags = {} - for train_tags in dir_tags.values(): - for tag, count in train_tags.items(): - tags[tag] = tags.get(tag, 0) + count - tags = list(tags.items()) - tags.sort(key=lambda x: x[1], reverse=True) - - result["success"] = True - result["info"] = info - if metadata is not None: - result["metadata"] = metadata - if tags is not None: - result["tags"] = tags - result["notes"] = notes - return web.json_response(result) - - -@server.PromptServer.instance.routes.get("/model-manager/system-separator") -async def get_system_separator(request): - return web.json_response(os.path.sep) - - -@server.PromptServer.instance.routes.post("/model-manager/model/download") -async def download_model(request): - formdata = await request.post() - result = { "success": False } - - overwrite = formdata.get("overwrite", "false").lower() - overwrite = True if overwrite == "true" else False - - model_path = formdata.get("path", "/0") - directory, model_type = search_path_to_system_path(model_path) - if directory is None: - result["alert"] = "Invalid save path!" - return web.json_response(result) - - download_uri = formdata.get("download") - if download_uri is None: - result["alert"] = "Invalid download url!" - return web.json_response(result) - - name = formdata.get("name") - model_extensions = folder_paths_get_supported_pt_extensions(model_type) - name_head, model_extension = split_valid_ext(name, model_extensions) - name_without_extension = os.path.split(name_head)[1] - if name_without_extension == "": - result["alert"] = "Cannot have empty model name!" - return web.json_response(result) - if model_extension == "": - result["alert"] = "Unrecognized model extension!" - return web.json_response(result) - file_name = os.path.join(directory, name) - try: - download_file(download_uri, file_name, overwrite) + task_id = await services.create_model_download_task(post) + return web.json_response({"success": True, "data": {"taskId": task_id}}) except Exception as e: - print(e, file=sys.stderr, flush=True) - result["alert"] = "Failed to download model!\n\n" + str(e) - return web.json_response(result) - - image = formdata.get("image") - if image is not None and image != "": - try: - download_model_preview({ - "path": model_path + os.sep + name, - "image": image, - "overwrite": formdata.get("overwrite"), - }) - except Exception as e: - print(e, file=sys.stderr, flush=True) - result["alert"] = "Failed to download preview!\n\n" + str(e) - - result["success"] = True - return web.json_response(result) + error_msg = f"Create model download task failed: {str(e)}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) -@server.PromptServer.instance.routes.post("/model-manager/model/move") -async def move_model(request): - body = await request.json() - result = { "success": False } - - old_file = body.get("oldFile", None) - if old_file is None: - result["alert"] = "No model was given!" - return web.json_response(result) - old_file, old_model_type = search_path_to_system_path(old_file) - if not os.path.isfile(old_file): - result["alert"] = "Model does not exist!" - return web.json_response(result) - old_model_extensions = folder_paths_get_supported_pt_extensions(old_model_type) - old_file_without_extension, model_extension = split_valid_ext(old_file, old_model_extensions) - if model_extension == "": - result["alert"] = "Invalid model extension!" - return web.json_response(result) - - new_file = body.get("newFile", None) - if new_file is None or new_file == "": - result["alert"] = "New model name was invalid!" - return web.json_response(result) - new_file, new_model_type = search_path_to_system_path(new_file) - if not new_file.endswith(model_extension): - result["alert"] = "Cannot change model extension!" - return web.json_response(result) - if os.path.isfile(new_file): - result["alert"] = "Cannot overwrite existing model!" - return web.json_response(result) - new_model_extensions = folder_paths_get_supported_pt_extensions(new_model_type) - new_file_without_extension, new_model_extension = split_valid_ext(new_file, new_model_extensions) - if model_extension != new_model_extension: - result["alert"] = "Cannot change model extension!" - return web.json_response(result) - new_file_dir, new_file_name = os.path.split(new_file) - if not os.path.isdir(new_file_dir): - result["alert"] = "Destination directory does not exist!" - return web.json_response(result) - new_name_without_extension = os.path.splitext(new_file_name)[0] - if new_file_name == new_name_without_extension or new_name_without_extension == "": - result["alert"] = "New model name was empty!" - return web.json_response(result) - - if old_file == new_file: - # no-op - result["success"] = True - return web.json_response(result) +@routes.get("/model-manager/models") +async def read_models(request): + """ + Scan all models and read their information. + """ try: - shutil.move(old_file, new_file) - print("Moved file: " + new_file) - except ValueError as e: - print(e, file=sys.stderr, flush=True) - result["alert"] = "Failed to move model!\n\n" + str(e) - return web.json_response(result) - # TODO: this could overwrite existing files in destination; do a check beforehand? - for extension in preview_extensions + (model_info_extension,): - old_file = old_file_without_extension + extension - if os.path.isfile(old_file): - new_file = new_file_without_extension + extension - try: - shutil.move(old_file, new_file) - print("Moved file: " + new_file) - except ValueError as e: - print(e, file=sys.stderr, flush=True) - msg = result.get("alert","") - if msg == "": - result["alert"] = "Failed to move model resource file!\n\n" + str(e) - else: - result["alert"] = msg + "\n" + str(e) - - result["success"] = True - return web.json_response(result) + result = [] + model_base_paths = config.model_base_paths + for model_type in model_base_paths: + result.extend(services.scan_models_by_model_type(model_type)) + result = [{"id": i, **x} for i, x in enumerate(result)] + return web.json_response({"success": True, "data": result}) + except Exception as e: + error_msg = f"Read models failed: {str(e)}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) -def delete_same_name_files(path_without_extension, extensions, keep_extension=None): - for extension in extensions: - if extension == keep_extension: continue - file = path_without_extension + extension - if os.path.isfile(file): - os.remove(file) - print("Deleted file: " + file) +@routes.put("/model-manager/model/{type}/{index}/{filename:.*}") +async def update_model(request): + """ + Update model information. + + request body: x-www-form-urlencoded + - previewFile: preview file. + - description: description. + - type: model type. + - pathIndex: index of the model folders. + - fullname: filename that relative to the model folder. + All fields are optional, but type, pathIndex and fullname must appear together. + """ + model_type = request.match_info.get("type", None) + index = int(request.match_info.get("index", None)) + filename = request.match_info.get("filename", None) + + post: dict = await request.post() + + try: + model_path = utils.get_valid_full_path(model_type, index, filename) + if model_path is None: + raise RuntimeError(f"File {filename} not found") + services.update_model(model_path, post) + return web.json_response({"success": True}) + except Exception as e: + error_msg = f"Update model failed: {str(e)}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) -@server.PromptServer.instance.routes.post("/model-manager/model/delete") +@routes.delete("/model-manager/model/{type}/{index}/{filename:.*}") async def delete_model(request): - result = { "success": False } + """ + Delete model. + """ + model_type = request.match_info.get("type", None) + index = int(request.match_info.get("index", None)) + filename = request.match_info.get("filename", None) - model_path = request.query.get("path", None) - if model_path is None: - result["alert"] = "Missing model path!" - return web.json_response(result) - model_path = urllib.parse.unquote(model_path) - model_path, model_type = search_path_to_system_path(model_path) - if model_path is None: - result["alert"] = "Invalid model path!" - return web.json_response(result) - - model_extensions = folder_paths_get_supported_pt_extensions(model_type) - path_and_name, model_extension = split_valid_ext(model_path, model_extensions) - if model_extension == "": - result["alert"] = "Cannot delete file!" - return web.json_response(result) - - if os.path.isfile(model_path): - os.remove(model_path) - result["success"] = True - print("Deleted file: " + model_path) - - delete_same_name_files(path_and_name, preview_extensions) - delete_same_name_files(path_and_name, (model_info_extension,)) - - return web.json_response(result) + try: + model_path = utils.get_valid_full_path(model_type, index, filename) + if model_path is None: + raise RuntimeError(f"File {filename} not found") + services.remove_model(model_path) + return web.json_response({"success": True}) + except Exception as e: + error_msg = f"Delete model failed: {str(e)}" + logging.error(error_msg) + logging.debug(traceback.format_exc()) + return web.json_response({"success": False, "error": error_msg}) -@server.PromptServer.instance.routes.post("/model-manager/notes/save") -async def set_notes(request): - body = await request.json() - result = { "success": False } +@routes.get("/model-manager/preview/{type}/{index}/{filename:.*}") +async def read_model_preview(request): + """ + Get the file stream of the specified image. + If the file does not exist, no-preview.png is returned. - dt_epoch = body.get("timestamp", None) + :param type: The type of the model. eg.checkpoints, loras, vae, etc. + :param index: The index of the model folders. + :param filename: The filename of the image. + """ + model_type = request.match_info.get("type", None) + index = int(request.match_info.get("index", None)) + filename = request.match_info.get("filename", None) - text = body.get("notes", None) - if type(text) is not str: - result["alert"] = "Invalid note!" - return web.json_response(result) + extension_uri = config.extension_uri - model_path = body.get("path", None) - if type(model_path) is not str: - result["alert"] = "Missing model path!" - return web.json_response(result) - model_path, model_type = search_path_to_system_path(model_path) - model_extensions = folder_paths_get_supported_pt_extensions(model_type) - file_path_without_extension, _ = split_valid_ext(model_path, model_extensions) - filename = os.path.normpath(file_path_without_extension + model_info_extension) - - if dt_epoch is not None and os.path.exists(filename) and os.path.getmtime(filename) > dt_epoch: - # discard late save - result["success"] = True - return web.json_response(result) - - if text.isspace() or text == "": - if os.path.exists(filename): - os.remove(filename) - #print("Deleted file: " + filename) # autosave -> too verbose - else: - try: - with open(filename, "w", encoding="utf-8") as f: - f.write(text) - if dt_epoch is not None: - os.utime(filename, (dt_epoch, dt_epoch)) - #print("Saved file: " + filename) # autosave -> too verbose - except ValueError as e: - print(e, file=sys.stderr, flush=True) - result["alert"] = "Failed to save notes!\n\n" + str(e) - web.json_response(result) + try: + folders = folder_paths.get_folder_paths(model_type) + base_path = folders[index] + abs_path = os.path.join(base_path, filename) + except: + abs_path = extension_uri - result["success"] = True - return web.json_response(result) + if not os.path.isfile(abs_path): + abs_path = os.path.join(extension_uri, "assets", "no-preview.png") + return web.FileResponse(abs_path) + + +@routes.get("/model-manager/preview/download/{filename}") +async def read_download_preview(request): + filename = request.match_info.get("filename", None) + extension_uri = config.extension_uri + + download_path = utils.get_download_path() + preview_path = os.path.join(download_path, filename) + + if not os.path.isfile(preview_path): + preview_path = os.path.join(extension_uri, "assets", "no-preview.png") + + return web.FileResponse(preview_path) WEB_DIRECTORY = "web" NODE_CLASS_MAPPINGS = {} -__all__ = ["NODE_CLASS_MAPPINGS"] +__all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS"] diff --git a/no-preview.png b/assets/no-preview.png similarity index 100% rename from no-preview.png rename to assets/no-preview.png diff --git a/config_loader.py b/config_loader.py deleted file mode 100644 index 2f9e895..0000000 --- a/config_loader.py +++ /dev/null @@ -1,72 +0,0 @@ -import yaml -from dataclasses import dataclass - -@dataclass -class Rule: - key: any - value_default: any - value_type: type - value_min: any # int | float | None - value_max: any # int | float | None - - def __init__( - self, - key, - value_default, - value_type: type, - value_min: any = None, # int | float | None - value_max: any = None, # int | float | None - ): - self.key = key - self.value_default = value_default - self.value_type = value_type - self.value_min = value_min - self.value_max = value_max - -def _get_valid_value(data: dict, r: Rule): - if r.value_type != type(r.value_default): - raise Exception(f"'value_type' does not match type of 'value_default'!") - value = data.get(r.key) - if value is None: - value = r.value_default - else: - try: - value = r.value_type(value) - except: - value = r.value_default - - value_is_numeric = r.value_type == int or r.value_type == float - if value_is_numeric and r.value_min: - if r.value_type != type(r.value_min): - raise Exception(f"Type of 'value_type' does not match the type of 'value_min'!") - value = max(r.value_min, value) - if value_is_numeric and r.value_max: - if r.value_type != type(r.value_max): - raise Exception(f"Type of 'value_type' does not match the type of 'value_max'!") - value = min(r.value_max, value) - - return value - -def validated(rules: list[Rule], data: dict = {}): - valid = {} - for r in rules: - valid[r.key] = _get_valid_value(data, r) - return valid - -def yaml_load(path, rules: list[Rule]): - data = {} - try: - with open(path, 'r') as file: - data = yaml.safe_load(file) - except: - pass - return validated(rules, data) - -def yaml_save(path, rules: list[Rule], data: dict) -> bool: - data = validated(rules, data) - try: - with open(path, 'w') as file: - yaml.dump(data, file) - return True - except: - return False diff --git a/demo/beta-menu-model-manager-button-settings-group.png b/demo/beta-menu-model-manager-button-settings-group.png deleted file mode 100644 index e5f0d7e..0000000 Binary files a/demo/beta-menu-model-manager-button-settings-group.png and /dev/null differ diff --git a/demo/tab-download.png b/demo/tab-download.png index 4e5f7a7..50ffa5f 100644 Binary files a/demo/tab-download.png and b/demo/tab-download.png differ diff --git a/demo/tab-model-drag-add.gif b/demo/tab-model-drag-add.gif deleted file mode 100644 index 897b474..0000000 Binary files a/demo/tab-model-drag-add.gif and /dev/null differ diff --git a/demo/tab-model-info-overview.png b/demo/tab-model-info-overview.png old mode 100644 new mode 100755 index 0637e75..bef89c4 Binary files a/demo/tab-model-info-overview.png and b/demo/tab-model-info-overview.png differ diff --git a/demo/tab-model-node-graph.gif b/demo/tab-model-node-graph.gif new file mode 100755 index 0000000..638bcdc Binary files /dev/null and b/demo/tab-model-node-graph.gif differ diff --git a/demo/tab-model-preview-thumbnail-buttons-example.png b/demo/tab-model-preview-thumbnail-buttons-example.png deleted file mode 100644 index 8f96ba6..0000000 Binary files a/demo/tab-model-preview-thumbnail-buttons-example.png and /dev/null differ diff --git a/demo/tab-models-dropdown.png b/demo/tab-models-dropdown.png deleted file mode 100644 index e23f763..0000000 Binary files a/demo/tab-models-dropdown.png and /dev/null differ diff --git a/demo/tab-models.gif b/demo/tab-models.gif new file mode 100755 index 0000000..b3e28f0 Binary files /dev/null and b/demo/tab-models.gif differ diff --git a/demo/tab-models.png b/demo/tab-models.png old mode 100644 new mode 100755 index 672ecea..f2a8825 Binary files a/demo/tab-models.png and b/demo/tab-models.png differ diff --git a/demo/tab-settings.png b/demo/tab-settings.png old mode 100644 new mode 100755 index 13ee3ef..04f3b49 Binary files a/demo/tab-settings.png and b/demo/tab-settings.png differ diff --git a/package.json b/package.json index 4b3fab8..0c1f49e 100644 --- a/package.json +++ b/package.json @@ -4,12 +4,16 @@ "version": "1.0.0", "type": "module", "scripts": { - "dev": "vite build --watch", + "dev": "vite", "build": "vite build", "prepare": "husky" }, "devDependencies": { + "@tailwindcss/container-queries": "^0.1.1", + "@types/lodash": "^4.17.9", + "@types/markdown-it": "^14.1.2", "@types/node": "^22.5.5", + "@types/turndown": "^5.0.5", "@vitejs/plugin-vue": "^5.1.4", "autoprefixer": "^10.4.20", "eslint": "^9.10.0", @@ -20,15 +24,23 @@ "postcss": "^8.4.47", "prettier": "^3.3.3", "prettier-plugin-organize-imports": "^4.1.0", + "prettier-plugin-tailwindcss": "^0.6.8", "tailwindcss": "^3.4.12", "typescript": "^5.6.2", "typescript-eslint": "^8.6.0", "vite": "^5.4.6" }, "dependencies": { + "@primevue/themes": "^4.0.7", + "dayjs": "^1.11.13", + "lodash": "^4.17.21", + "markdown-it": "^14.1.0", + "markdown-it-metadata-block": "^1.0.6", "primevue": "^4.0.7", + "turndown": "^7.2.0", "vue": "^3.4.31", - "vue-i18n": "^9.13.1" + "vue-i18n": "^9.13.1", + "yaml": "^2.6.0" }, "lint-staged": { "./**/*.{js,ts,tsx,vue}": [ diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index f75ee96..f155679 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -8,19 +8,52 @@ importers: .: dependencies: + '@primevue/themes': + specifier: ^4.0.7 + version: 4.0.7 + dayjs: + specifier: ^1.11.13 + version: 1.11.13 + lodash: + specifier: ^4.17.21 + version: 4.17.21 + markdown-it: + specifier: ^14.1.0 + version: 14.1.0 + markdown-it-metadata-block: + specifier: ^1.0.6 + version: 1.0.6 primevue: specifier: ^4.0.7 version: 4.0.7(vue@3.5.6(typescript@5.6.2)) + turndown: + specifier: ^7.2.0 + version: 7.2.0 vue: specifier: ^3.4.31 version: 3.5.6(typescript@5.6.2) vue-i18n: specifier: ^9.13.1 version: 9.14.0(vue@3.5.6(typescript@5.6.2)) + yaml: + specifier: ^2.6.0 + version: 2.6.0 devDependencies: + '@tailwindcss/container-queries': + specifier: ^0.1.1 + version: 0.1.1(tailwindcss@3.4.12) + '@types/lodash': + specifier: ^4.17.9 + version: 4.17.9 + '@types/markdown-it': + specifier: ^14.1.2 + version: 14.1.2 '@types/node': specifier: ^22.5.5 version: 22.5.5 + '@types/turndown': + specifier: ^5.0.5 + version: 5.0.5 '@vitejs/plugin-vue': specifier: ^5.1.4 version: 5.1.4(vite@5.4.6(@types/node@22.5.5)(less@4.2.0))(vue@3.5.6(typescript@5.6.2)) @@ -51,6 +84,9 @@ importers: prettier-plugin-organize-imports: specifier: ^4.1.0 version: 4.1.0(prettier@3.3.3)(typescript@5.6.2) + prettier-plugin-tailwindcss: + specifier: ^0.6.8 + version: 0.6.8(prettier-plugin-organize-imports@4.1.0(prettier@3.3.3)(typescript@5.6.2))(prettier@3.3.3) tailwindcss: specifier: ^3.4.12 version: 3.4.12 @@ -297,6 +333,9 @@ packages: '@jridgewell/trace-mapping@0.3.25': resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==} + '@mixmark-io/domino@2.2.0': + resolution: {integrity: sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw==} + '@nodelib/fs.scandir@2.1.5': resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==} engines: {node: '>= 8'} @@ -331,6 +370,10 @@ packages: resolution: {integrity: sha512-tj4dfRdV5iN6O0mbkpjhMsGlT3wZTqOPL779ndY5gKuCwN5zcFmKmABWVQmr/ClRivnMkw6Yr1x6gRTV/N0ydg==} engines: {node: '>=12.11.0'} + '@primevue/themes@4.0.7': + resolution: {integrity: sha512-ZbDUrpBmtuqdeegNwUaJTubaLDBBJWOc4Z6UoQM3DG2c7EAE19wQbuh+cG9zqA7sT/Xsp+ACC/Z9e4FnfqB55g==} + engines: {node: '>=12.11.0'} + '@rollup/rollup-android-arm-eabi@4.22.0': resolution: {integrity: sha512-/IZQvg6ZR0tAkEi4tdXOraQoWeJy9gbQ/cx4I7k9dJaCk9qrXEcdouxRVz5kZXt5C2bQ9pILoAA+KB4C/d3pfw==} cpu: [arm] @@ -420,12 +463,32 @@ packages: cpu: [x64] os: [win32] + '@tailwindcss/container-queries@0.1.1': + resolution: {integrity: sha512-p18dswChx6WnTSaJCSGx6lTmrGzNNvm2FtXmiO6AuA1V4U5REyoqwmT6kgAsIMdjo07QdAfYXHJ4hnMtfHzWgA==} + peerDependencies: + tailwindcss: '>=3.2.0' + '@types/estree@1.0.5': resolution: {integrity: sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==} + '@types/linkify-it@5.0.0': + resolution: {integrity: sha512-sVDA58zAw4eWAffKOaQH5/5j3XeayukzDk+ewSsnv3p4yJEZHCCzMDiZM8e0OUrRvmpGZ85jf4yDHkHsgBNr9Q==} + + '@types/lodash@4.17.9': + resolution: {integrity: sha512-w9iWudx1XWOHW5lQRS9iKpK/XuRhnN+0T7HvdCCd802FYkT1AMTnxndJHGrNJwRoRHkslGr4S29tjm1cT7x/7w==} + + '@types/markdown-it@14.1.2': + resolution: {integrity: sha512-promo4eFwuiW+TfGxhi+0x3czqTYJkG8qB17ZUJiVF10Xm7NLVRSLUsfRTU/6h1e24VvRnXCx+hG7li58lkzog==} + + '@types/mdurl@2.0.0': + resolution: {integrity: sha512-RGdgjQUZba5p6QEFAVx2OGb8rQDL/cPRG7GiedRzMcJ1tYnUANBncjbSB1NRGwbvjcPeikRABz2nshyPk1bhWg==} + '@types/node@22.5.5': resolution: {integrity: sha512-Xjs4y5UPO/CLdzpgR6GirZJx36yScjh73+2NlLlkFRSoQN8B0DpfXPdZGnvVmLRLOsqDpOfTNv7D9trgGhmOIA==} + '@types/turndown@5.0.5': + resolution: {integrity: sha512-TL2IgGgc7B5j78rIccBtlYAnkuv8nUQqhQc+DSYV5j9Be9XOcm/SKOVRuA47xAVI3680Tk9B1d8flK2GWT2+4w==} + '@typescript-eslint/eslint-plugin@8.6.0': resolution: {integrity: sha512-UOaz/wFowmoh2G6Mr9gw60B1mm0MzUtm6Ic8G2yM1Le6gyj5Loi/N+O5mocugRGY+8OeeKmkMmbxNqUCq3B4Sg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} @@ -667,6 +730,9 @@ packages: csstype@3.1.3: resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==} + dayjs@1.11.13: + resolution: {integrity: sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg==} + debug@4.3.7: resolution: {integrity: sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==} engines: {node: '>=6.0'} @@ -1002,6 +1068,9 @@ packages: lines-and-columns@1.2.4: resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} + linkify-it@5.0.0: + resolution: {integrity: sha512-5aHCbzQRADcdP+ATqnDuhhJ/MRIqDkZX5pyjFHRRysS8vZ5AbqGEoFIb6pYHPZ+L/OC2Lc+xT8uHVVR5CAK/wQ==} + lint-staged@15.2.10: resolution: {integrity: sha512-5dY5t743e1byO19P9I4b3x8HJwalIznL5E1FWYnU6OWw33KxNBSLAc6Cy7F2PsFEO8FKnLwjwm5hx7aMF0jzZg==} engines: {node: '>=18.12.0'} @@ -1035,6 +1104,16 @@ packages: resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==} engines: {node: '>=6'} + markdown-it-metadata-block@1.0.6: + resolution: {integrity: sha512-0nMBdV/CLy/bFfcw3wFdiZ6sgEv/yWAoNxgb3qY+5lLEP804r/JT9yLmLH3Z3YrqGDHb5xIi7gqhj7gwbPHycQ==} + + markdown-it@14.1.0: + resolution: {integrity: sha512-a54IwgWPaeBCAAsv13YgmALOF1elABB08FxO9i+r4VFk5Vl4pKokRPeX8u5TCgSsPi6ec1otfLjdOpVcgbpshg==} + hasBin: true + + mdurl@2.0.0: + resolution: {integrity: sha512-Lf+9+2r+Tdp5wXDXC4PcIBjTDtq4UKjCPMQhKIuzpJNW0b96kVqSwW0bT7FhRSfmAiFYgP+SCRvdrDozfh0U5w==} + merge-stream@2.0.0: resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==} @@ -1244,6 +1323,61 @@ packages: vue-tsc: optional: true + prettier-plugin-tailwindcss@0.6.8: + resolution: {integrity: sha512-dGu3kdm7SXPkiW4nzeWKCl3uoImdd5CTZEJGxyypEPL37Wj0HT2pLqjrvSei1nTeuQfO4PUfjeW5cTUNRLZ4sA==} + engines: {node: '>=14.21.3'} + peerDependencies: + '@ianvs/prettier-plugin-sort-imports': '*' + '@prettier/plugin-pug': '*' + '@shopify/prettier-plugin-liquid': '*' + '@trivago/prettier-plugin-sort-imports': '*' + '@zackad/prettier-plugin-twig-melody': '*' + prettier: ^3.0 + prettier-plugin-astro: '*' + prettier-plugin-css-order: '*' + prettier-plugin-import-sort: '*' + prettier-plugin-jsdoc: '*' + prettier-plugin-marko: '*' + prettier-plugin-multiline-arrays: '*' + prettier-plugin-organize-attributes: '*' + prettier-plugin-organize-imports: '*' + prettier-plugin-sort-imports: '*' + prettier-plugin-style-order: '*' + prettier-plugin-svelte: '*' + peerDependenciesMeta: + '@ianvs/prettier-plugin-sort-imports': + optional: true + '@prettier/plugin-pug': + optional: true + '@shopify/prettier-plugin-liquid': + optional: true + '@trivago/prettier-plugin-sort-imports': + optional: true + '@zackad/prettier-plugin-twig-melody': + optional: true + prettier-plugin-astro: + optional: true + prettier-plugin-css-order: + optional: true + prettier-plugin-import-sort: + optional: true + prettier-plugin-jsdoc: + optional: true + prettier-plugin-marko: + optional: true + prettier-plugin-multiline-arrays: + optional: true + prettier-plugin-organize-attributes: + optional: true + prettier-plugin-organize-imports: + optional: true + prettier-plugin-sort-imports: + optional: true + prettier-plugin-style-order: + optional: true + prettier-plugin-svelte: + optional: true + prettier@3.3.3: resolution: {integrity: sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==} engines: {node: '>=14'} @@ -1256,6 +1390,10 @@ packages: prr@1.0.1: resolution: {integrity: sha512-yPw4Sng1gWghHQWj0B3ZggWUm4qVbPwPFcRG8KyxiU7J2OHFSoEHKS+EZ3fv5l1t9CyCiop6l/ZYeWbrgoQejw==} + punycode.js@2.3.1: + resolution: {integrity: sha512-uxFIHU0YlHYhDQtV4R9J6a52SLx28BCjT+4ieh7IGbgwVJWO+km431c4yRlREUAsAmt/uMjQUyQHNEPf0M39CA==} + engines: {node: '>=6'} + punycode@2.3.1: resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==} engines: {node: '>=6'} @@ -1420,6 +1558,9 @@ packages: tslib@2.7.0: resolution: {integrity: sha512-gLXCKdN1/j47AiHiOkJN69hJmcbGTHI0ImLmbYLHykhgeN0jVGola9yVjFgzCUklsZQMW55o+dW7IXv3RCXDzA==} + turndown@7.2.0: + resolution: {integrity: sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==} + type-check@0.4.0: resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} engines: {node: '>= 0.8.0'} @@ -1442,6 +1583,9 @@ packages: engines: {node: '>=14.17'} hasBin: true + uc.micro@2.1.0: + resolution: {integrity: sha512-ARDJmphmdvUk6Glw7y9DQ2bFkKBHwQHLi2lsaH6PPmz/Ka9sFOBsBluozhDltWmnv9u/cF6Rt87znRTPV+yp/A==} + undici-types@6.19.8: resolution: {integrity: sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==} @@ -1538,6 +1682,11 @@ packages: engines: {node: '>= 14'} hasBin: true + yaml@2.6.0: + resolution: {integrity: sha512-a6ae//JvKDEra2kdi1qzCyrJW/WZCgFi8ydDV+eXExl95t+5R+ijnqHJbz9tmMh8FUjx3iv2fCQ4dclAQlO2UQ==} + engines: {node: '>= 14'} + hasBin: true + yocto-queue@0.1.0: resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} engines: {node: '>=10'} @@ -1708,6 +1857,8 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.0 + '@mixmark-io/domino@2.2.0': {} + '@nodelib/fs.scandir@2.1.5': dependencies: '@nodelib/fs.stat': 2.0.5 @@ -1742,6 +1893,10 @@ snapshots: transitivePeerDependencies: - vue + '@primevue/themes@4.0.7': + dependencies: + '@primeuix/styled': 0.0.5 + '@rollup/rollup-android-arm-eabi@4.22.0': optional: true @@ -1790,12 +1945,29 @@ snapshots: '@rollup/rollup-win32-x64-msvc@4.22.0': optional: true + '@tailwindcss/container-queries@0.1.1(tailwindcss@3.4.12)': + dependencies: + tailwindcss: 3.4.12 + '@types/estree@1.0.5': {} + '@types/linkify-it@5.0.0': {} + + '@types/lodash@4.17.9': {} + + '@types/markdown-it@14.1.2': + dependencies: + '@types/linkify-it': 5.0.0 + '@types/mdurl': 2.0.0 + + '@types/mdurl@2.0.0': {} + '@types/node@22.5.5': dependencies: undici-types: 6.19.8 + '@types/turndown@5.0.5': {} + '@typescript-eslint/eslint-plugin@8.6.0(@typescript-eslint/parser@8.6.0(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2))(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2)': dependencies: '@eslint-community/regexpp': 4.11.1 @@ -2074,6 +2246,8 @@ snapshots: csstype@3.1.3: {} + dayjs@1.11.13: {} + debug@4.3.7: dependencies: ms: 2.1.3 @@ -2430,6 +2604,10 @@ snapshots: lines-and-columns@1.2.4: {} + linkify-it@5.0.0: + dependencies: + uc.micro: 2.1.0 + lint-staged@15.2.10: dependencies: chalk: 5.3.0 @@ -2482,6 +2660,19 @@ snapshots: semver: 5.7.2 optional: true + markdown-it-metadata-block@1.0.6: {} + + markdown-it@14.1.0: + dependencies: + argparse: 2.0.1 + entities: 4.5.0 + linkify-it: 5.0.0 + mdurl: 2.0.0 + punycode.js: 2.3.1 + uc.micro: 2.1.0 + + mdurl@2.0.0: {} + merge-stream@2.0.0: {} merge2@1.4.1: {} @@ -2618,7 +2809,7 @@ snapshots: postcss-load-config@4.0.2(postcss@8.4.47): dependencies: lilconfig: 3.1.2 - yaml: 2.5.1 + yaml: 2.6.0 optionalDependencies: postcss: 8.4.47 @@ -2647,6 +2838,12 @@ snapshots: prettier: 3.3.3 typescript: 5.6.2 + prettier-plugin-tailwindcss@0.6.8(prettier-plugin-organize-imports@4.1.0(prettier@3.3.3)(typescript@5.6.2))(prettier@3.3.3): + dependencies: + prettier: 3.3.3 + optionalDependencies: + prettier-plugin-organize-imports: 4.1.0(prettier@3.3.3)(typescript@5.6.2) + prettier@3.3.3: {} primevue@4.0.7(vue@3.5.6(typescript@5.6.2)): @@ -2661,6 +2858,8 @@ snapshots: prr@1.0.1: optional: true + punycode.js@2.3.1: {} + punycode@2.3.1: {} queue-microtask@1.2.3: {} @@ -2849,6 +3048,10 @@ snapshots: tslib@2.7.0: {} + turndown@7.2.0: + dependencies: + '@mixmark-io/domino': 2.2.0 + type-check@0.4.0: dependencies: prelude-ls: 1.2.1 @@ -2868,6 +3071,8 @@ snapshots: typescript@5.6.2: {} + uc.micro@2.1.0: {} + undici-types@6.19.8: {} update-browserslist-db@1.1.0(browserslist@4.23.3): @@ -2950,4 +3155,6 @@ snapshots: yaml@2.5.1: {} + yaml@2.6.0: {} + yocto-queue@0.1.0: {} diff --git a/py/config.py b/py/config.py new file mode 100644 index 0000000..8efbdee --- /dev/null +++ b/py/config.py @@ -0,0 +1,33 @@ +extension_uri: str = None +model_base_paths: dict[str, list[str]] = {} + + +setting_key = { + "api_key": { + "civitai": "ModelManager.APIKey.Civitai", + "huggingface": "ModelManager.APIKey.HuggingFace", + }, + "download": { + "max_task_count": "ModelManager.Download.MaxTaskCount", + }, +} + +user_agent = "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148" + + +from server import PromptServer + +serverInstance = PromptServer.instance +routes = serverInstance.routes + + +class FakeRequest: + def __init__(self): + self.headers = {} + + +class CustomException(BaseException): + def __init__(self, type: str, message: str = None) -> None: + self.type = type + self.message = message + super().__init__(message) diff --git a/py/download.py b/py/download.py new file mode 100644 index 0000000..ca8f42d --- /dev/null +++ b/py/download.py @@ -0,0 +1,362 @@ +import os +import uuid +import time +import logging +import requests +import folder_paths +import traceback +from typing import Callable, Awaitable, Any, Literal, Union, Optional +from dataclasses import dataclass +from . import config +from . import utils +from . import socket +from . import thread + + +@dataclass +class TaskStatus: + taskId: str + type: str + fullname: str + preview: str + status: Literal["pause", "waiting", "doing"] = "pause" + platform: Union[str, None] = None + downloadedSize: float = 0 + totalSize: float = 0 + progress: float = 0 + bps: float = 0 + error: Optional[str] = None + + +@dataclass +class TaskContent: + type: str + pathIndex: int + fullname: str + description: str + downloadPlatform: str + downloadUrl: str + sizeBytes: float + hashes: Optional[dict[str, str]] = None + + +download_model_task_status: dict[str, TaskStatus] = {} +download_thread_pool = thread.DownloadThreadPool() + + +def set_task_content(task_id: str, task_content: Union[TaskContent, dict]): + download_path = utils.get_download_path() + task_file_path = os.path.join(download_path, f"{task_id}.task") + utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content)) + + +def get_task_content(task_id: str): + download_path = utils.get_download_path() + task_file = os.path.join(download_path, f"{task_id}.task") + if not os.path.isfile(task_file): + raise RuntimeError(f"Task {task_id} not found") + task_content = utils.load_dict_pickle_file(task_file) + task_content["pathIndex"] = int(task_content.get("pathIndex", 0)) + task_content["sizeBytes"] = float(task_content.get("sizeBytes", 0)) + return TaskContent(**task_content) + + +def get_task_status(task_id: str): + task_status = download_model_task_status.get(task_id, None) + + if task_status is None: + download_path = utils.get_download_path() + task_content = get_task_content(task_id) + download_file = os.path.join(download_path, f"{task_id}.download") + download_size = 0 + if os.path.exists(download_file): + download_size = os.path.getsize(download_file) + + total_size = task_content.sizeBytes + task_status = TaskStatus( + taskId=task_id, + type=task_content.type, + fullname=task_content.fullname, + preview=utils.get_model_preview_name(download_file), + platform=task_content.downloadPlatform, + downloadedSize=download_size, + totalSize=task_content.sizeBytes, + progress=download_size / total_size * 100 if total_size > 0 else 0, + ) + + download_model_task_status[task_id] = task_status + + return task_status + + +def delete_task_status(task_id: str): + download_model_task_status.pop(task_id, None) + + +async def scan_model_download_task_list(sid: str): + """ + Scan the download directory and send the task list to the client. + """ + try: + download_dir = utils.get_download_path() + task_files = utils.search_files(download_dir) + task_files = folder_paths.filter_files_extensions(task_files, [".task"]) + task_files = sorted( + task_files, + key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime, + reverse=True, + ) + task_list: list[dict] = [] + for task_file in task_files: + task_id = task_file.replace(".task", "") + task_status = get_task_status(task_id) + task_list.append(task_status) + + await socket.send_json("downloadTaskList", task_list, sid) + except Exception as e: + error_msg = f"Refresh task list failed: {e}" + await socket.send_json("error", error_msg, sid) + logging.error(error_msg) + + +async def create_model_download_task(post: dict): + """ + Creates a download task for the given post. + """ + model_type = post.get("type", None) + path_index = int(post.get("pathIndex", None)) + fullname = post.get("fullname", None) + + model_path = utils.get_full_path(model_type, path_index, fullname) + # Check if the model path is valid + if os.path.exists(model_path): + raise RuntimeError(f"File already exists: {model_path}") + + download_path = utils.get_download_path() + + task_id = uuid.uuid4().hex + task_path = os.path.join(download_path, f"{task_id}.task") + if os.path.exists(task_path): + raise RuntimeError(f"Task {task_id} already exists") + + try: + previewFile = post.pop("previewFile", None) + utils.save_model_preview_image(task_path, previewFile) + set_task_content(task_id, post) + task_status = TaskStatus( + taskId=task_id, + type=model_type, + fullname=fullname, + preview=utils.get_model_preview_name(task_path), + platform=post.get("downloadPlatform", None), + totalSize=float(post.get("sizeBytes", 0)), + ) + download_model_task_status[task_id] = task_status + await socket.send_json("createDownloadTask", task_status) + except Exception as e: + await delete_model_download_task(task_id) + raise RuntimeError(str(e)) from e + + await download_model(task_id) + return task_id + + +async def pause_model_download_task(task_id: str): + task_status = get_task_status(task_id=task_id) + task_status.status = "pause" + + +async def delete_model_download_task(task_id: str): + task_status = get_task_status(task_id) + is_running = task_status.status == "doing" + task_status.status = "waiting" + await socket.send_json("deleteDownloadTask", task_id) + + # Pause the task + if is_running: + task_status.status = "pause" + time.sleep(1) + + download_dir = utils.get_download_path() + task_file_list = os.listdir(download_dir) + for task_file in task_file_list: + task_file_target = os.path.splitext(task_file)[0] + if task_file_target == task_id: + delete_task_status(task_id) + os.remove(os.path.join(download_dir, task_file)) + + await socket.send_json("deleteDownloadTask", task_id) + + +async def download_model(task_id: str): + async def download_task(task_id: str): + async def report_progress(task_status: TaskStatus): + await socket.send_json("updateDownloadTask", task_status) + + try: + # When starting a task from the queue, the task may not exist + task_status = get_task_status(task_id) + except: + return + + # Update task status + task_status.status = "doing" + await socket.send_json("updateDownloadTask", task_status) + + try: + + # Set download request headers + headers = {"User-Agent": config.user_agent} + + download_platform = task_status.platform + if download_platform == "civitai": + api_key = utils.get_setting_value("api_key.civitai") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + elif download_platform == "huggingface": + api_key = utils.get_setting_value("api_key.huggingface") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + progress_interval = 1.0 + await download_model_file( + task_id=task_id, + headers=headers, + progress_callback=report_progress, + interval=progress_interval, + ) + except Exception as e: + task_status.status = "pause" + task_status.error = str(e) + await socket.send_json("updateDownloadTask", task_status) + task_status.error = None + logging.error(str(e)) + + try: + status = download_thread_pool.submit(download_task, task_id) + if status == "Waiting": + task_status = get_task_status(task_id) + task_status.status = "waiting" + await socket.send_json("updateDownloadTask", task_status) + except Exception as e: + task_status.status = "pause" + task_status.error = str(e) + await socket.send_json("updateDownloadTask", task_status) + task_status.error = None + logging.error(traceback.format_exc()) + + +async def download_model_file( + task_id: str, + headers: dict, + progress_callback: Callable[[TaskStatus], Awaitable[Any]], + interval: float = 1.0, +): + + async def download_complete(): + """ + Restore the model information from the task file + and move the model file to the target directory. + """ + model_type = task_content.type + path_index = task_content.pathIndex + fullname = task_content.fullname + # Write description file + description = task_content.description + description_file = os.path.join(download_path, f"{task_id}.md") + with open(description_file, "w") as f: + f.write(description) + + model_path = utils.get_full_path(model_type, path_index, fullname) + + utils.rename_model(download_tmp_file, model_path) + + time.sleep(1) + task_file = os.path.join(download_path, f"{task_id}.task") + os.remove(task_file) + await socket.send_json("completeDownloadTask", task_id) + + async def update_progress(): + nonlocal last_update_time + nonlocal last_downloaded_size + progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0 + task_status.downloadedSize = downloaded_size + task_status.progress = progress + task_status.bps = downloaded_size - last_downloaded_size + await progress_callback(task_status) + last_update_time = time.time() + last_downloaded_size = downloaded_size + + task_status = get_task_status(task_id) + task_content = get_task_content(task_id) + + # Check download uri + model_url = task_content.downloadUrl + if not model_url: + raise RuntimeError("No downloadUrl found") + + download_path = utils.get_download_path() + download_tmp_file = os.path.join(download_path, f"{task_id}.download") + + downloaded_size = 0 + if os.path.isfile(download_tmp_file): + downloaded_size = os.path.getsize(download_tmp_file) + headers["Range"] = f"bytes={downloaded_size}-" + + total_size = task_content.sizeBytes + + if total_size > 0 and downloaded_size == total_size: + await download_complete() + return + + last_update_time = time.time() + last_downloaded_size = downloaded_size + + response = requests.get( + url=model_url, + headers=headers, + stream=True, + allow_redirects=True, + ) + + if response.status_code not in (200, 206): + raise RuntimeError( + f"Failed to download {task_content.fullname}, status code: {response.status_code}" + ) + + # Some models require logging in before they can be downloaded. + # If no token is carried, it will be redirected to the login page. + content_type = response.headers.get("content-type") + if content_type and content_type.startswith("text/html"): + raise RuntimeError( + f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first." + ) + + # When parsing model information from HuggingFace API, + # the file size was not found and needs to be obtained from the response header. + if total_size == 0: + total_size = int(response.headers.get("content-length", 0)) + task_content.sizeBytes = total_size + task_status.totalSize = total_size + set_task_content(task_id, task_content) + await socket.send_json("updateDownloadTask", task_content) + + with open(download_tmp_file, "ab") as f: + for chunk in response.iter_content(chunk_size=8192): + if task_status.status == "pause": + break + + f.write(chunk) + downloaded_size += len(chunk) + + if time.time() - last_update_time >= interval: + await update_progress() + + await update_progress() + + if total_size > 0 and downloaded_size == total_size: + await download_complete() + else: + task_status.status = "pause" + await socket.send_json("updateDownloadTask", task_status) diff --git a/py/services.py b/py/services.py new file mode 100644 index 0000000..835b095 --- /dev/null +++ b/py/services.py @@ -0,0 +1,145 @@ +import os +import logging +import traceback +import folder_paths + +from typing import Any +from multidict import MultiDictProxy +from . import utils +from . import socket +from . import download + + +async def connect_websocket(request): + async def message_handler(event_type: str, detail: Any, sid: str): + try: + if event_type == "downloadTaskList": + await download.scan_model_download_task_list(sid=sid) + + if event_type == "resumeDownloadTask": + await download.download_model(task_id=detail) + + if event_type == "pauseDownloadTask": + await download.pause_model_download_task(task_id=detail) + + if event_type == "deleteDownloadTask": + await download.delete_model_download_task(task_id=detail) + except Exception: + logging.error(traceback.format_exc()) + + ws = await socket.create_websocket_handler(request, handler=message_handler) + return ws + + +def scan_models_by_model_type(model_type: str): + """ + Scans all models in the given model type and returns a list of models. + """ + out = [] + folders, extensions = folder_paths.folder_names_and_paths[model_type] + for path_index, base_path in enumerate(folders): + files = utils.recursive_search_files(base_path) + + models = folder_paths.filter_files_extensions(files, extensions) + + for fullname in models: + """ + fullname is model path relative to base_path + eg. + abs_path is /path/to/models/stable-diffusion/custom_group/model_name.ckpt + base_path is /path/to/models/stable-diffusion + fullname is custom_group/model_name.ckpt + basename is custom_group/model_name + extension is .ckpt + """ + + fullname = fullname.replace(os.path.sep, "/") + basename = os.path.splitext(fullname)[0] + extension = os.path.splitext(fullname)[1] + prefix_path = fullname.replace(os.path.basename(fullname), "") + + abs_path = os.path.join(base_path, fullname) + file_stats = os.stat(abs_path) + + # Resolve metadata + metadata = utils.get_model_metadata(abs_path) + + # Resolve preview + image_name = utils.get_model_preview_name(abs_path) + image_name = os.path.join(prefix_path, image_name) + abs_image_path = os.path.join(base_path, image_name) + if os.path.isfile(abs_image_path): + image_state = os.stat(abs_image_path) + image_timestamp = round(image_state.st_mtime_ns / 1000000) + image_name = f"{image_name}?ts={image_timestamp}" + model_preview = ( + f"/model-manager/preview/{model_type}/{path_index}/{image_name}" + ) + + # Resolve description + description_file = utils.get_model_description_name(abs_path) + description_file = os.path.join(prefix_path, description_file) + abs_desc_path = os.path.join(base_path, description_file) + description = None + if os.path.isfile(abs_desc_path): + with open(abs_desc_path, "r", encoding="utf-8") as f: + description = f.read() + + out.append( + { + "fullname": fullname, + "basename": basename, + "extension": extension, + "type": model_type, + "pathIndex": path_index, + "sizeBytes": file_stats.st_size, + "preview": model_preview, + "description": description, + "createdAt": round(file_stats.st_ctime_ns / 1000000), + "updatedAt": round(file_stats.st_mtime_ns / 1000000), + "metadata": metadata, + } + ) + + return out + + +def update_model(model_path: str, post: MultiDictProxy): + + if "previewFile" in post: + previewFile = post["previewFile"] + utils.save_model_preview_image(model_path, previewFile) + + if "description" in post: + description = post["description"] + utils.save_model_description(model_path, description) + + if "type" in post and "pathIndex" in post and "fullname" in post: + model_type = post.get("type", None) + path_index = int(post.get("pathIndex", None)) + fullname = post.get("fullname", None) + if model_type is None or path_index is None or fullname is None: + raise RuntimeError("Invalid type or pathIndex or fullname") + + # get new path + new_model_path = utils.get_full_path(model_type, path_index, fullname) + + utils.rename_model(model_path, new_model_path) + + +def remove_model(model_path: str): + model_dirname = os.path.dirname(model_path) + os.remove(model_path) + + model_previews = utils.get_model_all_images(model_path) + for preview in model_previews: + os.remove(os.path.join(model_dirname, preview)) + + model_descriptions = utils.get_model_all_descriptions(model_path) + for description in model_descriptions: + os.remove(os.path.join(model_dirname, description)) + + +async def create_model_download_task(post): + dict_post = dict(post) + return await download.create_model_download_task(dict_post) diff --git a/py/socket.py b/py/socket.py new file mode 100644 index 0000000..13a39f9 --- /dev/null +++ b/py/socket.py @@ -0,0 +1,63 @@ +import aiohttp +import logging +import uuid +import json +from aiohttp import web +from typing import Any, Callable, Awaitable +from . import utils + + +__sockets: dict[str, web.WebSocketResponse] = {} + + +async def create_websocket_handler( + request, handler: Callable[[str, Any, str], Awaitable[Any]] +): + ws = web.WebSocketResponse() + await ws.prepare(request) + sid = request.rel_url.query.get("clientId", "") + if sid: + # Reusing existing session, remove old + __sockets.pop(sid, None) + else: + sid = uuid.uuid4().hex + + __sockets[sid] = ws + + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.ERROR: + logging.warning( + "ws connection closed with exception %s" % ws.exception() + ) + if msg.type == aiohttp.WSMsgType.TEXT: + data = json.loads(msg.data) + await handler(data.get("type"), data.get("detail"), sid) + finally: + __sockets.pop(sid, None) + return ws + + +async def send_json(event: str, data: Any, sid: str = None): + detail = utils.unpack_dataclass(data) + message = {"type": event, "data": detail} + + if sid is None: + socket_list = list(__sockets.values()) + for ws in socket_list: + await __send_socket_catch_exception(ws.send_json, message) + elif sid in __sockets: + await __send_socket_catch_exception(__sockets[sid].send_json, message) + + +async def __send_socket_catch_exception(function, message): + try: + await function(message) + except ( + aiohttp.ClientError, + aiohttp.ClientPayloadError, + ConnectionResetError, + BrokenPipeError, + ConnectionError, + ) as err: + logging.warning("send error: {}".format(err)) diff --git a/py/thread.py b/py/thread.py new file mode 100644 index 0000000..82689f7 --- /dev/null +++ b/py/thread.py @@ -0,0 +1,64 @@ +import asyncio +import threading +import queue +import logging +from . import utils + + +class DownloadThreadPool: + def __init__(self) -> None: + self.workers_count = 0 + self.task_queue = queue.Queue() + self.running_tasks = set() + self._lock = threading.Lock() + + default_max_workers = 5 + max_workers: int = utils.get_setting_value( + "download.max_task_count", default_max_workers + ) + + if max_workers <= 0: + max_workers = default_max_workers + utils.set_setting_value("download.max_task_count", max_workers) + + self.max_worker = max_workers + + def submit(self, task, task_id): + with self._lock: + if task_id in self.running_tasks: + return "Existing" + self.running_tasks.add(task_id) + self.task_queue.put((task, task_id)) + return self._adjust_worker_count() + + def _adjust_worker_count(self): + if self.workers_count < self.max_worker: + self._start_worker() + return "Running" + else: + return "Waiting" + + def _start_worker(self): + t = threading.Thread(target=self._worker, daemon=True) + t.start() + with self._lock: + self.workers_count += 1 + + def _worker(self): + loop = asyncio.new_event_loop() + + while True: + if self.task_queue.empty(): + break + + task, task_id = self.task_queue.get() + + try: + loop.run_until_complete(task(task_id)) + with self._lock: + self.running_tasks.remove(task_id) + except Exception as e: + logging.error(f"worker run error: {str(e)}") + + with self._lock: + self.workers_count -= 1 diff --git a/py/utils.py b/py/utils.py new file mode 100644 index 0000000..fcdac9f --- /dev/null +++ b/py/utils.py @@ -0,0 +1,282 @@ +import os +import comfy.utils +import json +import logging +import folder_paths +from aiohttp import web +from typing import Any +from . import config + + +def resolve_model_base_paths(): + folders = list(folder_paths.folder_names_and_paths.keys()) + config.model_base_paths = {} + for folder in folders: + if folder == "configs": + continue + if folder == "custom_nodes": + continue + config.model_base_paths[folder] = folder_paths.get_folder_paths(folder) + + +def get_full_path(model_type: str, path_index: int, filename: str): + """ + Get the absolute path in the model type through string concatenation. + """ + folders = config.model_base_paths.get(model_type, []) + if not path_index < len(folders): + raise RuntimeError(f"PathIndex {path_index} is not in {model_type}") + base_path = folders[path_index] + return os.path.join(base_path, filename) + + +def get_valid_full_path(model_type: str, path_index: int, filename: str): + """ + Like get_full_path but it will check whether the file is valid. + """ + folders = config.model_base_paths.get(model_type, []) + if not path_index < len(folders): + raise RuntimeError(f"PathIndex {path_index} is not in {model_type}") + base_path = folders[path_index] + full_path = os.path.join(base_path, filename) + if os.path.isfile(full_path): + return full_path + elif os.path.islink(full_path): + raise RuntimeError( + f"WARNING path {full_path} exists but doesn't link anywhere, skipping." + ) + + +def get_download_path(): + download_path = os.path.join(config.extension_uri, "downloads") + if not os.path.exists(download_path): + os.makedirs(download_path) + return download_path + + +def recursive_search_files(directory: str): + files, folder_all = folder_paths.recursive_search( + directory, excluded_dir_names=[".git"] + ) + files.sort() + return files + + +def search_files(directory: str): + entries = os.listdir(directory) + files = [f for f in entries if os.path.isfile(os.path.join(directory, f))] + files.sort() + return files + + +def get_model_metadata(filename: str): + if not filename.endswith(".safetensors"): + return {} + try: + out = comfy.utils.safetensors_header(filename, max_size=1024 * 1024) + if out is None: + return {} + dt = json.loads(out) + if not "__metadata__" in dt: + return {} + return dt["__metadata__"] + except: + return {} + + +def get_model_all_images(model_path: str): + base_dirname = os.path.dirname(model_path) + files = search_files(base_dirname) + files = folder_paths.filter_files_content_types(files, ["image"]) + + basename = os.path.splitext(os.path.basename(model_path))[0] + output: list[str] = [] + for file in files: + file_basename = os.path.splitext(file)[0] + if file_basename == basename: + output.append(file) + if file_basename == f"{basename}.preview": + output.append(file) + return output + + +def get_model_preview_name(model_path: str): + images = get_model_all_images(model_path) + return images[0] if len(images) > 0 else "no-preview.png" + + +def save_model_preview_image(model_path: str, image_file: Any): + if not isinstance(image_file, web.FileField): + raise RuntimeError("Invalid image file") + + content_type: str = image_file.content_type + if not content_type.startswith("image/"): + raise RuntimeError(f"FileTypeError: expected image, got {content_type}") + + base_dirname = os.path.dirname(model_path) + + # remove old preview images + old_preview_images = get_model_all_images(model_path) + a1111_civitai_helper_image = False + for image in old_preview_images: + if os.path.splitext(image)[1].endswith(".preview"): + a1111_civitai_helper_image = True + image_path = os.path.join(base_dirname, image) + os.remove(image_path) + + # save new preview image + basename = os.path.splitext(os.path.basename(model_path))[0] + extension = f".{content_type.split('/')[1]}" + new_preview_path = os.path.join(base_dirname, f"{basename}{extension}") + + with open(new_preview_path, "wb") as f: + f.write(image_file.file.read()) + + # TODO Is it possible to abandon the current rules and adopt the rules of a1111 civitai_helper? + if a1111_civitai_helper_image: + """ + Keep preview image of a1111_civitai_helper + """ + new_preview_path = os.path.join(base_dirname, f"{basename}.preview{extension}") + with open(new_preview_path, "wb") as f: + f.write(image_file.file.read()) + + +def get_model_all_descriptions(model_path: str): + base_dirname = os.path.dirname(model_path) + files = search_files(base_dirname) + files = folder_paths.filter_files_extensions(files, [".txt", ".md"]) + + basename = os.path.splitext(os.path.basename(model_path))[0] + output: list[str] = [] + for file in files: + file_basename = os.path.splitext(file)[0] + if file_basename == basename: + output.append(file) + return output + + +def get_model_description_name(model_path: str): + descriptions = get_model_all_descriptions(model_path) + basename = os.path.splitext(os.path.basename(model_path))[0] + return descriptions[0] if len(descriptions) > 0 else f"{basename}.md" + + +def save_model_description(model_path: str, content: Any): + if not isinstance(content, str): + raise RuntimeError("Invalid description") + + base_dirname = os.path.dirname(model_path) + + # remove old descriptions + old_descriptions = get_model_all_descriptions(model_path) + for desc in old_descriptions: + description_path = os.path.join(base_dirname, desc) + os.remove(description_path) + + # save new description + basename = os.path.splitext(os.path.basename(model_path))[0] + extension = ".md" + new_desc_path = os.path.join(base_dirname, f"{basename}{extension}") + + with open(new_desc_path, "w", encoding="utf-8") as f: + f.write(content) + + +def rename_model(model_path: str, new_model_path: str): + if model_path == new_model_path: + return + + if os.path.exists(new_model_path): + raise RuntimeError(f"Model {new_model_path} already exists") + + model_name = os.path.splitext(os.path.basename(model_path))[0] + new_model_name = os.path.splitext(os.path.basename(new_model_path))[0] + + model_dirname = os.path.dirname(model_path) + new_model_dirname = os.path.dirname(new_model_path) + + if not os.path.exists(new_model_dirname): + os.makedirs(new_model_dirname) + + # move model + os.rename(model_path, new_model_path) + + # move preview + previews = get_model_all_images(model_path) + for preview in previews: + preview_path = os.path.join(model_dirname, preview) + preview_name = os.path.splitext(preview)[0] + preview_ext = os.path.splitext(preview)[1] + new_preview_path = ( + os.path.join(new_model_dirname, new_model_name + preview_ext) + if preview_name == model_name + else os.path.join( + new_model_dirname, new_model_name + ".preview" + preview_ext + ) + ) + os.rename(preview_path, new_preview_path) + + # move description + description = get_model_description_name(model_path) + description_path = os.path.join(model_dirname, description) + if os.path.isfile(description_path): + new_description_path = os.path.join(new_model_dirname, f"{new_model_name}.md") + os.rename(description_path, new_description_path) + + +import pickle + + +def save_dict_pickle_file(filename: str, data: dict): + with open(filename, "wb") as f: + pickle.dump(data, f) + + +def load_dict_pickle_file(filename: str) -> dict: + with open(filename, "rb") as f: + data = pickle.load(f) + return data + + +def resolve_setting_key(key: str) -> str: + key_paths = key.split(".") + setting_id = config.setting_key + try: + for key_path in key_paths: + setting_id = setting_id[key_path] + except: + pass + if not isinstance(setting_id, str): + raise RuntimeError(f"Invalid key: {key}") + + return setting_id + + +def set_setting_value(key: str, value: Any): + setting_id = resolve_setting_key(key) + fake_request = config.FakeRequest() + settings = config.serverInstance.user_manager.settings.get_settings(fake_request) + settings[setting_id] = value + config.serverInstance.user_manager.settings.save_settings(fake_request, settings) + + +def get_setting_value(key: str, default: Any = None) -> Any: + setting_id = resolve_setting_key(key) + fake_request = config.FakeRequest() + settings = config.serverInstance.user_manager.settings.get_settings(fake_request) + return settings.get(setting_id, default) + + +from dataclasses import asdict, is_dataclass + + +def unpack_dataclass(data: Any): + if isinstance(data, dict): + return {key: unpack_dataclass(value) for key, value in data.items()} + elif isinstance(data, list): + return [unpack_dataclass(x) for x in data] + elif is_dataclass(data): + return asdict(data) + else: + return data diff --git a/src/App.vue b/src/App.vue new file mode 100644 index 0000000..5eb7ae4 --- /dev/null +++ b/src/App.vue @@ -0,0 +1,41 @@ + + + diff --git a/src/components/DialogCreateTask.vue b/src/components/DialogCreateTask.vue new file mode 100644 index 0000000..d186db0 --- /dev/null +++ b/src/components/DialogCreateTask.vue @@ -0,0 +1,160 @@ + + + diff --git a/src/components/DialogDownload.vue b/src/components/DialogDownload.vue new file mode 100644 index 0000000..42f9722 --- /dev/null +++ b/src/components/DialogDownload.vue @@ -0,0 +1,167 @@ + + + diff --git a/src/components/DialogManager.vue b/src/components/DialogManager.vue new file mode 100644 index 0000000..19b06b6 --- /dev/null +++ b/src/components/DialogManager.vue @@ -0,0 +1,223 @@ + + + diff --git a/src/components/DialogModelCard.vue b/src/components/DialogModelCard.vue new file mode 100644 index 0000000..67608fc --- /dev/null +++ b/src/components/DialogModelCard.vue @@ -0,0 +1,97 @@ + + + diff --git a/src/components/DialogModelDetail.vue b/src/components/DialogModelDetail.vue new file mode 100644 index 0000000..ac64eee --- /dev/null +++ b/src/components/DialogModelDetail.vue @@ -0,0 +1,103 @@ + + + diff --git a/src/components/DialogResizer.vue b/src/components/DialogResizer.vue new file mode 100644 index 0000000..1cae6c4 --- /dev/null +++ b/src/components/DialogResizer.vue @@ -0,0 +1,303 @@ + + + diff --git a/src/components/FormWrapper.vue b/src/components/FormWrapper.vue new file mode 100644 index 0000000..049a4f3 --- /dev/null +++ b/src/components/FormWrapper.vue @@ -0,0 +1,17 @@ + + + diff --git a/src/components/GlobalLoading.vue b/src/components/GlobalLoading.vue new file mode 100644 index 0000000..f0be0e5 --- /dev/null +++ b/src/components/GlobalLoading.vue @@ -0,0 +1,15 @@ + + + diff --git a/src/components/GlobalToast.vue b/src/components/GlobalToast.vue new file mode 100644 index 0000000..af090ee --- /dev/null +++ b/src/components/GlobalToast.vue @@ -0,0 +1,22 @@ + + + diff --git a/src/components/ModelBaseInfo.vue b/src/components/ModelBaseInfo.vue new file mode 100644 index 0000000..2ff4f85 --- /dev/null +++ b/src/components/ModelBaseInfo.vue @@ -0,0 +1,90 @@ + + + diff --git a/src/components/ModelContent.vue b/src/components/ModelContent.vue new file mode 100644 index 0000000..e8c9128 --- /dev/null +++ b/src/components/ModelContent.vue @@ -0,0 +1,96 @@ + + + diff --git a/src/components/ModelDescription.vue b/src/components/ModelDescription.vue new file mode 100644 index 0000000..b42204a --- /dev/null +++ b/src/components/ModelDescription.vue @@ -0,0 +1,91 @@ + + + diff --git a/src/components/ModelMetadata.vue b/src/components/ModelMetadata.vue new file mode 100644 index 0000000..69f885d --- /dev/null +++ b/src/components/ModelMetadata.vue @@ -0,0 +1,37 @@ + + + diff --git a/src/components/ModelPreview.vue b/src/components/ModelPreview.vue new file mode 100644 index 0000000..6b45e93 --- /dev/null +++ b/src/components/ModelPreview.vue @@ -0,0 +1,112 @@ + + + diff --git a/src/components/ResponseFileUpload.vue b/src/components/ResponseFileUpload.vue new file mode 100644 index 0000000..24f2754 --- /dev/null +++ b/src/components/ResponseFileUpload.vue @@ -0,0 +1,56 @@ + + + diff --git a/src/components/ResponseImage.vue b/src/components/ResponseImage.vue new file mode 100644 index 0000000..8de839c --- /dev/null +++ b/src/components/ResponseImage.vue @@ -0,0 +1,36 @@ + + + diff --git a/src/components/ResponseInput.vue b/src/components/ResponseInput.vue new file mode 100644 index 0000000..1083998 --- /dev/null +++ b/src/components/ResponseInput.vue @@ -0,0 +1,82 @@ + + + + + diff --git a/src/components/ResponseScrollArea.vue b/src/components/ResponseScrollArea.vue new file mode 100644 index 0000000..d001c49 --- /dev/null +++ b/src/components/ResponseScrollArea.vue @@ -0,0 +1,214 @@ + + + diff --git a/src/components/ResponseSelect.vue b/src/components/ResponseSelect.vue new file mode 100644 index 0000000..0aef152 --- /dev/null +++ b/src/components/ResponseSelect.vue @@ -0,0 +1,234 @@ + + + diff --git a/src/hooks/config.ts b/src/hooks/config.ts new file mode 100644 index 0000000..628e183 --- /dev/null +++ b/src/hooks/config.ts @@ -0,0 +1,69 @@ +import { useRequest } from 'hooks/request' +import { defineStore } from 'hooks/store' +import { app } from 'scripts/comfyAPI' +import { onMounted, onUnmounted, ref } from 'vue' + +export const useConfig = defineStore('config', () => { + const mobileDeviceBreakPoint = 759 + const isMobile = ref(window.innerWidth < mobileDeviceBreakPoint) + + type ModelFolder = Record + const { data: modelFolders, refresh: refreshModelFolders } = + useRequest('/base-folders') + + const checkDeviceType = () => { + isMobile.value = window.innerWidth < mobileDeviceBreakPoint + } + + onMounted(() => { + window.addEventListener('resize', checkDeviceType) + }) + + onUnmounted(() => { + window.removeEventListener('resize', checkDeviceType) + }) + + const refreshSetting = async () => { + return Promise.all([refreshModelFolders()]) + } + + const config = { + isMobile, + gutter: 16, + cardWidth: 240, + aspect: 7 / 9, + modelFolders, + refreshSetting, + } + + useAddConfigSettings(config) + + return config +}) + +type Config = ReturnType + +declare module 'hooks/store' { + interface StoreProvider { + config: Config + } +} + +function useAddConfigSettings(config: Config) { + onMounted(() => { + // API keys + app.ui?.settings.addSetting({ + id: 'ModelManager.APIKey.HuggingFace', + name: 'HuggingFace API Key', + type: 'text', + defaultValue: undefined, + }) + + app.ui?.settings.addSetting({ + id: 'ModelManager.APIKey.Civitai', + name: 'Civitai API Key', + type: 'text', + defaultValue: undefined, + }) + }) +} diff --git a/src/hooks/download.ts b/src/hooks/download.ts new file mode 100644 index 0000000..2da8fab --- /dev/null +++ b/src/hooks/download.ts @@ -0,0 +1,423 @@ +import { useLoading } from 'hooks/loading' +import { MarkdownTool, useMarkdown } from 'hooks/markdown' +import { socket } from 'hooks/socket' +import { defineStore } from 'hooks/store' +import { useToast } from 'hooks/toast' +import { useBoolean } from 'hooks/utils' +import { bytesToSize } from 'utils/common' +import { onBeforeMount, onMounted, ref } from 'vue' +import { useI18n } from 'vue-i18n' + +export const useDownload = defineStore('download', (store) => { + const [visible, toggle] = useBoolean() + const { toast, confirm } = useToast() + const { t } = useI18n() + + const taskList = ref([]) + + const refresh = () => { + socket.send('downloadTaskList', null) + } + + const createTaskItem = (item: DownloadTaskOptions) => { + const { downloadedSize, totalSize, bps, ...rest } = item + + const task: DownloadTask = { + ...rest, + preview: `/model-manager/preview/download/${item.preview}`, + downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`, + downloadSpeed: `${bytesToSize(bps)}/s`, + pauseTask() { + socket.send('pauseDownloadTask', item.taskId) + }, + resumeTask: () => { + socket.send('resumeDownloadTask', item.taskId) + }, + deleteTask: () => { + confirm.require({ + message: t('deleteAsk', [t('downloadTask').toLowerCase()]), + header: 'Danger', + icon: 'pi pi-info-circle', + rejectProps: { + label: t('cancel'), + severity: 'secondary', + outlined: true, + }, + acceptProps: { + label: t('delete'), + severity: 'danger', + }, + accept: () => { + socket.send('deleteDownloadTask', item.taskId) + }, + reject: () => {}, + }) + }, + } + + return task + } + + onBeforeMount(() => { + socket.addEventListener('reconnected', () => { + refresh() + }) + + socket.addEventListener('downloadTaskList', (event) => { + const data = event.detail as DownloadTaskOptions[] + + taskList.value = data.map((item) => { + return createTaskItem(item) + }) + }) + + socket.addEventListener('createDownloadTask', (event) => { + const item = event.detail as DownloadTaskOptions + taskList.value.unshift(createTaskItem(item)) + }) + + socket.addEventListener('updateDownloadTask', (event) => { + const item = event.detail as DownloadTaskOptions + + for (const task of taskList.value) { + if (task.taskId === item.taskId) { + if (item.error) { + toast.add({ + severity: 'error', + summary: 'Error', + detail: item.error, + life: 15000, + }) + item.error = undefined + } + Object.assign(task, createTaskItem(item)) + } + } + }) + + socket.addEventListener('deleteDownloadTask', (event) => { + const taskId = event.detail as string + taskList.value = taskList.value.filter((item) => item.taskId !== taskId) + }) + + socket.addEventListener('completeDownloadTask', (event) => { + const taskId = event.detail as string + const task = taskList.value.find((item) => item.taskId === taskId) + taskList.value = taskList.value.filter((item) => item.taskId !== taskId) + toast.add({ + severity: 'success', + summary: 'Success', + detail: `${task?.fullname} Download completed`, + life: 2000, + }) + store.models.refresh() + }) + }) + + onMounted(() => { + refresh() + }) + + return { visible, toggle, data: taskList, refresh } +}) + +declare module 'hooks/store' { + interface StoreProvider { + download: ReturnType + } +} + +abstract class ModelSearch { + constructor(readonly md: MarkdownTool) {} + + abstract search(pathname: string): Promise +} + +class Civitai extends ModelSearch { + async search(searchUrl: string): Promise { + const { pathname, searchParams } = new URL(searchUrl) + + const [, modelId] = pathname.match(/^\/models\/(\d*)/) ?? [] + const versionId = searchParams.get('modelVersionId') + + if (!modelId) { + return Promise.resolve([]) + } + + return fetch(`https://civitai.com/api/v1/models/${modelId}`) + .then((response) => response.json()) + .then((resData) => { + const modelVersions: any[] = resData.modelVersions.filter( + (version: any) => { + if (versionId) { + return version.id == versionId + } + return true + }, + ) + + const models: VersionModel[] = [] + + for (const version of modelVersions) { + const modelFiles: any[] = version.files.filter( + (file: any) => file.type === 'Model', + ) + + const shortname = modelFiles.length > 0 ? version.name : undefined + + for (const file of modelFiles) { + const fullname = file.name + const extension = `.${fullname.split('.').pop()}` + const basename = fullname.replace(extension, '') + + models.push({ + id: file.id, + shortname: shortname ?? basename, + fullname: fullname, + basename: basename, + extension: extension, + preview: version.images.map((i: any) => i.url), + sizeBytes: file.sizeKB * 1024, + type: this.resolveType(resData.type), + pathIndex: 0, + description: [ + '---', + `website: Civitai`, + ``, + `modelPage: https://civitai.com/models/${modelId}?modelVersionId=${version.id}`, + '---', + '', + '# Trigger Words', + `\n${(version.trainedWords ?? ['No trigger words']).join(', ')}\n`, + '# About this version', + this.resolveDescription( + version.description, + '\nNo description about this version\n', + ), + `# ${resData.name}`, + this.resolveDescription( + resData.description, + 'No description about this model', + ), + ].join('\n'), + metadata: file.metadata, + downloadPlatform: 'civitai', + downloadUrl: file.downloadUrl, + hashes: file.hashes, + }) + } + } + + return models + }) + } + + private resolveType(type: string) { + const mapLegacy = { + TextualInversion: 'embeddings', + LoCon: 'loras', + DoRA: 'loras', + Controlnet: 'controlnet', + Upscaler: 'upscale_models', + VAE: 'vae', + } + return mapLegacy[type] ?? `${type.toLowerCase()}s` + } + + private resolveDescription(content: string, defaultContent: string) { + const mdContent = this.md.parse(content ?? '').trim() + return mdContent || defaultContent + } +} + +class Huggingface extends ModelSearch { + async search(searchUrl: string): Promise { + const { pathname } = new URL(searchUrl) + const [, space, name, ...restPaths] = pathname.split('/') + + if (!space || !name) { + return Promise.resolve([]) + } + + const modelId = `${space}/${name}` + const restPathname = restPaths.join('/') + + return fetch(`https://huggingface.co/api/models/${modelId}`) + .then((response) => response.json()) + .then((resData) => { + const siblingFiles: string[] = resData.siblings.map( + (item: any) => item.rfilename, + ) + + const modelFiles: string[] = this.filterTreeFiles( + this.filterModelFiles(siblingFiles), + restPathname, + ) + const images: string[] = this.filterTreeFiles( + this.filterImageFiles(siblingFiles), + restPathname, + ).map((filename) => { + return `https://huggingface.co/${modelId}/resolve/main/${filename}` + }) + + const models: VersionModel[] = [] + + for (const filename of modelFiles) { + const fullname = filename.split('/').pop()! + const extension = `.${fullname.split('.').pop()}` + const basename = fullname.replace(extension, '') + + models.push({ + id: filename, + shortname: filename, + fullname: fullname, + basename: basename, + extension: extension, + preview: images, + sizeBytes: 0, + type: 'unknown', + pathIndex: 0, + description: [ + '---', + `website: HuggingFace`, + `author: ${resData.author}`, + `modelPage: https://huggingface.co/${modelId}`, + '---', + '', + '# Trigger Words', + '\nNo trigger words\n', + '# About this version', + '\nNo description about this version\n', + `# ${resData.modelId}`, + '\nNo description about this model\n', + ].join('\n'), + metadata: {}, + downloadPlatform: 'huggingface', + downloadUrl: `https://huggingface.co/${modelId}/resolve/main/${filename}?download=true`, + }) + } + + return models + }) + } + + private filterTreeFiles(files: string[], pathname: string) { + const [target, , ...paths] = pathname.split('/') + + if (!target) return files + + if (target !== 'tree' && target !== 'blob') return files + + const pathPrefix = paths.join('/') + return files.filter((file) => { + return file.startsWith(pathPrefix) + }) + } + + private filterModelFiles(files: string[]) { + const extension = [ + '.bin', + '.ckpt', + '.gguf', + '.onnx', + '.pt', + '.pth', + '.safetensors', + ] + return files.filter((file) => { + const ext = file.split('.').pop() + return ext ? extension.includes(`.${ext}`) : false + }) + } + + private filterImageFiles(files: string[]) { + const extension = [ + '.png', + '.webp', + '.jpeg', + '.jpg', + '.jfif', + '.gif', + '.apng', + ] + + return files.filter((file) => { + const ext = file.split('.').pop() + return ext ? extension.includes(`.${ext}`) : false + }) + } +} + +class UnknownWebsite extends ModelSearch { + async search(searchUrl: string): Promise { + return Promise.reject( + new Error( + 'Unknown Website, please input a URL from huggingface.co or civitai.com.', + ), + ) + } +} + +export const useModelSearch = () => { + const loading = useLoading() + const md = useMarkdown() + const { toast } = useToast() + const data = ref<(SelectOptions & { item: VersionModel })[]>([]) + const current = ref() + + const handleSearchByUrl = async (url: string) => { + if (!url) { + return Promise.resolve([]) + } + + let instance: ModelSearch = new UnknownWebsite(md) + + const { hostname } = new URL(url ?? '') + + if (hostname === 'civitai.com') { + instance = new Civitai(md) + } + + if (hostname === 'huggingface.co') { + instance = new Huggingface(md) + } + + loading.show() + return instance + .search(url) + .then((resData) => { + data.value = resData.map((item) => ({ + label: item.shortname, + value: item.id, + item, + command() { + current.value = item.id + }, + })) + current.value = data.value[0]?.value + + if (resData.length === 0) { + toast.add({ + severity: 'warn', + summary: 'No Model Found', + detail: `No model found for ${url}`, + life: 3000, + }) + } + + return resData + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message, + life: 15000, + }) + return [] + }) + .finally(() => loading.hide()) + } + + return { data, current, search: handleSearchByUrl } +} diff --git a/src/hooks/loading.ts b/src/hooks/loading.ts new file mode 100644 index 0000000..4a6b0b8 --- /dev/null +++ b/src/hooks/loading.ts @@ -0,0 +1,55 @@ +import { defineStore } from 'hooks/store' +import { useBoolean } from 'hooks/utils' +import { Ref, ref } from 'vue' + +class GlobalLoading { + loading: Ref + + loadingStack = 0 + + bind(loading: Ref) { + this.loading = loading + } + + show() { + this.loadingStack++ + this.loading.value = true + } + + hide() { + this.loadingStack-- + if (this.loadingStack <= 0) this.loading.value = false + } +} + +export const globalLoading = new GlobalLoading() + +export const useGlobalLoading = defineStore('loading', () => { + const [loading] = useBoolean() + + globalLoading.bind(loading) + + return { loading } +}) + +export const useLoading = () => { + const timer = ref() + + const show = () => { + timer.value = setTimeout(() => { + timer.value = undefined + globalLoading.show() + }, 200) + } + + const hide = () => { + if (timer.value) { + clearTimeout(timer.value) + timer.value = undefined + } else { + globalLoading.hide() + } + } + + return { show, hide } +} diff --git a/src/hooks/manager.ts b/src/hooks/manager.ts new file mode 100644 index 0000000..9e37c98 --- /dev/null +++ b/src/hooks/manager.ts @@ -0,0 +1,27 @@ +import { defineStore } from 'hooks/store' +import { useBoolean } from 'hooks/utils' +import { ref, watch } from 'vue' + +export const useDialogManager = defineStore('dialogManager', () => { + const [visible, toggle] = useBoolean() + + const mounted = ref(false) + const open = ref(false) + + watch(visible, (visible) => { + open.value = visible + mounted.value = true + }) + + const updateVisible = (val: boolean) => { + visible.value = val + } + + return { visible: mounted, open, updateVisible, toggle } +}) + +declare module 'hooks/store' { + interface StoreProvider { + dialogManager: ReturnType + } +} diff --git a/src/hooks/markdown.ts b/src/hooks/markdown.ts new file mode 100644 index 0000000..a358d1e --- /dev/null +++ b/src/hooks/markdown.ts @@ -0,0 +1,49 @@ +import MarkdownIt from 'markdown-it' +import metadata_block from 'markdown-it-metadata-block' +import TurndownService from 'turndown' +import yaml from 'yaml' + +interface MarkdownOptions { + metadata?: Record +} + +export const useMarkdown = (opts?: MarkdownOptions) => { + const md = new MarkdownIt({ + html: true, + linkify: true, + typographer: true, + }) + + md.use(metadata_block, { + parseMetadata: yaml.parse, + meta: opts?.metadata ?? {}, + }) + + md.renderer.rules.link_open = function (tokens, idx, options, env, self) { + const aIndex = tokens[idx].attrIndex('target') + + if (aIndex < 0) { + tokens[idx].attrPush(['target', '_blank']) + } else { + tokens[idx].attrs![aIndex][1] = '_blank' + } + + return self.renderToken(tokens, idx, options) + } + + const turndown = new TurndownService({ + headingStyle: 'atx', + bulletListMarker: '-', + }) + + turndown.addRule('paragraph', { + filter: 'p', + replacement: function (content) { + return `\n\n${content}` + }, + }) + + return { render: md.render.bind(md), parse: turndown.turndown.bind(turndown) } +} + +export type MarkdownTool = ReturnType diff --git a/src/hooks/model.ts b/src/hooks/model.ts new file mode 100644 index 0000000..04c8817 --- /dev/null +++ b/src/hooks/model.ts @@ -0,0 +1,547 @@ +import { useLoading } from 'hooks/loading' +import { useMarkdown } from 'hooks/markdown' +import { request, useRequest } from 'hooks/request' +import { defineStore } from 'hooks/store' +import { useToast } from 'hooks/toast' +import { cloneDeep } from 'lodash' +import { app } from 'scripts/comfyAPI' +import { bytesToSize, formatDate, previewUrlToFile } from 'utils/common' +import { ModelGrid } from 'utils/legacy' +import { resolveModelType } from 'utils/model' +// import {} +import { + computed, + inject, + InjectionKey, + onMounted, + provide, + ref, + toRaw, + unref, +} from 'vue' +import { useI18n } from 'vue-i18n' + +export const useModels = defineStore('models', () => { + const { data, refresh } = useRequest<(Model & { visible?: boolean })[]>( + '/models', + { defaultValue: [] }, + ) + const { toast, confirm } = useToast() + const { t } = useI18n() + const loading = useLoading() + + const updateModel = async (model: BaseModel, data: BaseModel) => { + const formData = new FormData() + + // Check current preview + if (model.preview !== data.preview) { + const previewFile = await previewUrlToFile(data.preview as string) + formData.append('previewFile', previewFile) + } + + // Check current description + if (model.description !== data.description) { + formData.append('description', data.description) + } + + // Check current name and pathIndex + if ( + model.fullname !== data.fullname || + model.pathIndex !== data.pathIndex + ) { + formData.append('type', data.type) + formData.append('pathIndex', data.pathIndex.toString()) + formData.append('fullname', data.fullname) + } + + if (formData.keys().next().done) { + return + } + + loading.show() + await request(`/model/${model.type}/${model.pathIndex}/${model.fullname}`, { + method: 'PUT', + body: formData, + }) + .catch(() => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: 'Failed to update model', + life: 15000, + }) + }) + .finally(() => { + loading.hide() + }) + + await refresh() + } + + const deleteModel = async (model: BaseModel) => { + return new Promise((resolve) => { + confirm.require({ + message: t('deleteAsk', [t('model').toLowerCase()]), + header: 'Danger', + icon: 'pi pi-info-circle', + rejectProps: { + label: t('cancel'), + severity: 'secondary', + outlined: true, + }, + acceptProps: { + label: t('delete'), + severity: 'danger', + }, + accept: () => { + loading.show() + request(`/model/${model.type}/${model.pathIndex}/${model.fullname}`, { + method: 'DELETE', + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Success', + detail: `${model.fullname} Deleted`, + life: 2000, + }) + return refresh() + }) + .then(() => { + resolve(void 0) + }) + .catch((e) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: e.message ?? 'Failed to delete model', + life: 15000, + }) + }) + .finally(() => { + loading.hide() + }) + }, + reject: () => {}, + }) + }) + } + + return { data, refresh, remove: deleteModel, update: updateModel } +}) + +declare module 'hooks/store' { + interface StoreProvider { + models: ReturnType + } +} + +export const useModelFormData = (getFormData: () => BaseModel) => { + const formData = ref(getFormData()) + const modelData = ref(getFormData()) + + type ResetCallback = () => void + const resetCallback = ref([]) + + const registerReset = (callback: ResetCallback) => { + resetCallback.value.push(callback) + } + + const reset = () => { + formData.value = getFormData() + modelData.value = getFormData() + for (const callback of resetCallback.value) { + callback() + } + } + + type SubmitCallback = (data: BaseModel) => void + const submitCallback = ref([]) + + const registerSubmit = (callback: SubmitCallback) => { + submitCallback.value.push(callback) + } + + const submit = () => { + const data = cloneDeep(toRaw(unref(formData))) + for (const callback of submitCallback.value) { + callback(data) + } + return data + } + + const metadata = ref>({}) + + return { + formData, + modelData, + registerReset, + reset, + registerSubmit, + submit, + metadata, + } +} + +type ModelFormInstance = ReturnType + +/** + * Model base info + */ +const baseInfoKey = Symbol('baseInfo') as InjectionKey< + ReturnType +> + +export const useModelBaseInfoEditor = (formInstance: ModelFormInstance) => { + const { formData: model, modelData } = formInstance + + const type = computed({ + get: () => { + return model.value.type + }, + set: (val) => { + model.value.type = val + }, + }) + + const pathIndex = computed({ + get: () => { + return model.value.pathIndex + }, + set: (val) => { + model.value.pathIndex = val + }, + }) + + const extension = computed(() => { + return model.value.extension + }) + + const basename = computed({ + get: () => { + return model.value.fullname.replace(model.value.extension, '') + }, + set: (val) => { + model.value.fullname = `${val ?? ''}${model.value.extension}` + }, + }) + + interface BaseInfoItem { + key: string + display: string + value: any + } + + interface FieldsItem { + key: keyof Model + formatter: (val: any) => string + } + + const baseInfo = computed(() => { + const fields: FieldsItem[] = [ + { + key: 'type', + formatter: () => resolveModelType(modelData.value.type).display, + }, + { + key: 'fullname', + formatter: (val) => val, + }, + { + key: 'sizeBytes', + formatter: (val) => (val == 0 ? 'Unknown' : bytesToSize(val)), + }, + { + key: 'createdAt', + formatter: (val) => val && formatDate(val), + }, + { + key: 'updatedAt', + formatter: (val) => val && formatDate(val), + }, + ] + + const information: Record = {} + for (const item of fields) { + const key = item.key + const value = model.value[key] + const display = item.formatter(value) + + if (display) { + information[key] = { key, value, display } + } + } + + return information + }) + + const result = { + type, + baseInfo, + basename, + extension, + pathIndex, + } + + provide(baseInfoKey, result) + + return result +} + +export const useModelBaseInfo = () => { + return inject(baseInfoKey)! +} + +/** + * Editable preview image. + * + * In edit mode, there are 4 methods for setting a preview picture: + * 1. default value, which is the default image of the model type + * 2. network picture + * 3. local file + * 4. no preview + */ +const previewKey = Symbol('preview') as InjectionKey< + ReturnType +> + +export const useModelPreviewEditor = (formInstance: ModelFormInstance) => { + const { formData: model, registerReset, registerSubmit } = formInstance + + const typeOptions = ref(['default', 'network', 'local', 'none']) + const currentType = ref('default') + + /** + * Default images + */ + const defaultContent = computed(() => { + return Array.isArray(model.value.preview) + ? model.value.preview + : [model.value.preview] + }) + const defaultContentPage = ref(0) + + /** + * Network picture url + */ + const networkContent = ref() + + /** + * Local file url + */ + const localContent = ref() + const updateLocalContent = async (event: SelectEvent) => { + const { files } = event + localContent.value = files[0].objectURL + } + + /** + * No preview + */ + const noPreviewContent = computed(() => { + return `/model-manager/preview/${model.value.type}/0/no-preview.png` + }) + + const preview = computed(() => { + let content: string | undefined + + switch (currentType.value) { + case 'default': + content = defaultContent.value[defaultContentPage.value] + break + case 'network': + content = networkContent.value + break + case 'local': + content = localContent.value + break + default: + content = noPreviewContent.value + break + } + + return content + }) + + onMounted(() => { + registerReset(() => { + currentType.value = 'default' + defaultContentPage.value = 0 + networkContent.value = undefined + localContent.value = undefined + }) + + registerSubmit((data) => { + data.preview = preview.value ?? noPreviewContent.value + }) + }) + + const result = { + preview, + typeOptions, + currentType, + // default value + defaultContent, + defaultContentPage, + // network picture + networkContent, + // local file + localContent, + updateLocalContent, + // no preview + noPreviewContent, + } + + provide(previewKey, result) + + return result +} + +export const useModelPreview = () => { + return inject(previewKey)! +} + +/** + * Model description + */ +const descriptionKey = Symbol('description') as InjectionKey< + ReturnType +> + +export const useModelDescriptionEditor = (formInstance: ModelFormInstance) => { + const { formData: model, metadata } = formInstance + + const md = useMarkdown({ metadata: metadata.value }) + + const description = computed({ + get: () => { + return model.value.description + }, + set: (val) => { + model.value.description = val + }, + }) + + const renderedDescription = computed(() => { + return description.value ? md.render(description.value) : undefined + }) + + const result = { renderedDescription, description } + + provide(descriptionKey, result) + + return result +} + +export const useModelDescription = () => { + return inject(descriptionKey)! +} + +/** + * Model metadata + */ +const metadataKey = Symbol('metadata') as InjectionKey< + ReturnType +> + +export const useModelMetadataEditor = (formInstance: ModelFormInstance) => { + const { formData: model } = formInstance + + const metadata = computed(() => { + return model.value.metadata + }) + + const result = { metadata } + + provide(metadataKey, result) + + return result +} + +export const useModelMetadata = () => { + return inject(metadataKey)! +} + +export const useModelNodeAction = (model: BaseModel) => { + const { t } = useI18n() + const { toast, wrapperToastError } = useToast() + + const createNode = (options: Record = {}) => { + const nodeType = resolveModelType(model.type).loader + if (!nodeType) { + throw new Error(t('unSupportedModelType', [model.type])) + } + + const node = window.LiteGraph.createNode(nodeType, null, options) + const widgetIndex = node.widgets.findIndex((w) => w.type === 'combo') + if (widgetIndex > -1) { + node.widgets[widgetIndex].value = model.fullname + } + return node + } + + const dragToAddModelNode = wrapperToastError((event: DragEvent) => { + // const target = document.elementFromPoint(event.clientX, event.clientY) + // if ( + // target?.tagName.toLocaleLowerCase() === 'canvas' && + // target.id === 'graph-canvas' + // ) { + // const pos = app.clientPosToCanvasPos([event.clientX - 20, event.clientY]) + // const node = createNode({ pos }) + // app.graph.add(node) + // app.canvas.selectNode(node) + // } + // + // Use the legacy method instead + const removeEmbeddingExtension = true + const strictDragToAdd = false + + ModelGrid.dragAddModel( + event, + model.type, + model.fullname, + removeEmbeddingExtension, + strictDragToAdd, + ) + }) + + const addModelNode = wrapperToastError(() => { + const selectedNodes = app.canvas.selected_nodes + const firstSelectedNode = Object.values(selectedNodes)[0] + const offset = 25 + const pos = firstSelectedNode + ? [firstSelectedNode.pos[0] + offset, firstSelectedNode.pos[1] + offset] + : app.canvas.canvas_mouse + const node = createNode({ pos }) + app.graph.add(node) + app.canvas.selectNode(node) + }) + + const copyModelNode = wrapperToastError(() => { + const node = createNode() + app.canvas.copyToClipboard([node]) + toast.add({ + severity: 'success', + summary: 'Success', + detail: t('modelCopied'), + life: 2000, + }) + }) + + const loadPreviewWorkflow = wrapperToastError(async () => { + const previewUrl = model.preview as string + const response = await fetch(previewUrl) + const data = await response.blob() + const type = data.type + const extension = type.split('/').pop() + const file = new File([data], `${model.fullname}.${extension}`, { type }) + app.handleFile(file) + }) + + return { + addModelNode, + dragToAddModelNode, + copyModelNode, + loadPreviewWorkflow, + } +} diff --git a/src/hooks/request.ts b/src/hooks/request.ts new file mode 100644 index 0000000..840df82 --- /dev/null +++ b/src/hooks/request.ts @@ -0,0 +1,85 @@ +import { useLoading } from 'hooks/loading' +import { api } from 'scripts/comfyAPI' +import { onMounted, ref } from 'vue' + +export const request = async (url: string, options?: RequestInit) => { + return api + .fetchApi(`/model-manager${url}`, options) + .then((response) => response.json()) + .then((resData) => { + if (resData.success) { + return resData.data + } + throw new Error(resData.error) + }) +} + +export interface RequestOptions { + method?: RequestInit['method'] + headers?: RequestInit['headers'] + defaultParams?: Record + defaultValue?: any + postData?: (data: T) => T + manual?: boolean +} + +export const useRequest = ( + url: string, + options: RequestOptions = {}, +) => { + const loading = useLoading() + const postData = options.postData ?? ((data) => data) + + const data = ref(options.defaultValue) + const lastParams = ref() + + const fetch = async ( + params: Record = options.defaultParams ?? {}, + ) => { + loading.show() + + lastParams.value = params + + let requestUrl = url + const requestOptions: RequestInit = { + method: options.method, + headers: options.headers, + } + const requestParams = { ...params } + + const templatePattern = /\{(.*?)\}/g + const urlParamKeyMatches = requestUrl.matchAll(templatePattern) + for (const urlParamKey of urlParamKeyMatches) { + const [match, paramKey] = urlParamKey + if (paramKey in requestParams) { + const paramValue = requestParams[paramKey] + delete requestParams[paramKey] + requestUrl = requestUrl.replace(match, paramValue) + } + } + + if (!requestOptions.method) { + requestOptions.method = 'GET' + } + + if (requestOptions.method !== 'GET') { + requestOptions.body = JSON.stringify(requestParams) + } + + return request(requestUrl, requestOptions) + .then((resData) => (data.value = postData(resData))) + .finally(() => loading.hide()) + } + + onMounted(() => { + if (!options.manual) { + fetch() + } + }) + + const refresh = async () => { + return fetch(lastParams.value) + } + + return { data, refresh, fetch } +} diff --git a/src/hooks/resize.ts b/src/hooks/resize.ts new file mode 100644 index 0000000..5ca8f4b --- /dev/null +++ b/src/hooks/resize.ts @@ -0,0 +1,22 @@ +import { throttle } from 'lodash' +import { Directive } from 'vue' + +export const resizeDirective: Directive = { + mounted: (el, binding) => { + const callback = binding.value ?? (() => {}) + const observer = new ResizeObserver(callback) + observer.observe(el) + el['observer'] = observer + }, + unmounted: (el) => { + const observer = el['observer'] + observer.disconnect() + }, +} + +export const defineResizeCallback = ( + callback: ResizeObserverCallback, + wait?: number, +) => { + return throttle(callback, wait ?? 100) +} diff --git a/src/hooks/socket.ts b/src/hooks/socket.ts new file mode 100644 index 0000000..582a43a --- /dev/null +++ b/src/hooks/socket.ts @@ -0,0 +1,82 @@ +import { globalToast } from 'hooks/toast' +import { readonly } from 'vue' + +class WebSocketEvent extends EventTarget { + private socket: WebSocket | null + + constructor() { + super() + this.createSocket() + } + + private createSocket(isReconnect?: boolean) { + const api_host = location.host + const api_base = location.pathname.split('/').slice(0, -1).join('/') + + let opened = false + let existingSession = window.name + if (existingSession) { + existingSession = '?clientId=' + existingSession + } + + this.socket = readonly( + new WebSocket( + `ws${window.location.protocol === 'https:' ? 's' : ''}://${api_host}${api_base}/model-manager/ws${existingSession}`, + ), + ) + + this.socket.addEventListener('open', () => { + opened = true + if (isReconnect) { + this.dispatchEvent(new CustomEvent('reconnected')) + } + }) + + this.socket.addEventListener('error', () => { + if (this.socket) this.socket.close() + }) + + this.socket.addEventListener('close', (event) => { + setTimeout(() => { + this.socket = null + this.createSocket(true) + }, 300) + if (opened) { + this.dispatchEvent(new CustomEvent('status', { detail: null })) + this.dispatchEvent(new CustomEvent('reconnecting')) + } + }) + + this.socket.addEventListener('message', (event) => { + try { + const msg = JSON.parse(event.data) + if (msg.type === 'error') { + globalToast.value?.add({ + severity: 'error', + summary: 'Error', + detail: msg.data, + life: 15000, + }) + } else { + this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })) + } + } catch (error) { + console.error(error) + } + }) + } + + addEventListener = ( + type: string, + callback: CustomEventListener | null, + options?: AddEventListenerOptions | boolean, + ) => { + super.addEventListener(type, callback, options) + } + + send(type: string, data: any) { + this.socket?.send(JSON.stringify({ type, detail: data })) + } +} + +export const socket = new WebSocketEvent() diff --git a/src/hooks/store.ts b/src/hooks/store.ts new file mode 100644 index 0000000..96431bd --- /dev/null +++ b/src/hooks/store.ts @@ -0,0 +1,51 @@ +import { inject, InjectionKey, provide } from 'vue' + +const providerHooks = new Map() +const storeEvent = {} as StoreProvider + +export const useStoreProvider = () => { + // const storeEvent = {} + + for (const [key, useHook] of providerHooks) { + storeEvent[key] = useHook() + } + + return storeEvent +} + +const storeKeys = new Map() + +const getStoreKey = (key: string) => { + let storeKey = storeKeys.get(key) + if (!storeKey) { + storeKey = Symbol(key) + storeKeys.set(key, storeKey) + } + return storeKey +} + +/** + * Using vue provide and inject to implement a simple store + */ +export const defineStore = ( + key: string, + useInitial: (event: StoreProvider) => T, +) => { + const storeKey = getStoreKey(key) as InjectionKey + + if (providerHooks.has(key) && !import.meta.hot) { + console.warn(`[defineStore] key: ${key} already exists.`) + } else { + providerHooks.set(key, () => { + const result = useInitial(storeEvent) + provide(storeKey, result ?? storeEvent[key]) + return result + }) + } + + const useStore = () => { + return inject(storeKey)! + } + + return useStore +} diff --git a/src/hooks/toast.ts b/src/hooks/toast.ts new file mode 100644 index 0000000..4e2f683 --- /dev/null +++ b/src/hooks/toast.ts @@ -0,0 +1,45 @@ +import { ToastServiceMethods } from 'primevue/toastservice' +import { useConfirm as usePrimeConfirm } from 'primevue/useconfirm' +import { useToast as usePrimeToast } from 'primevue/usetoast' + +export const globalToast = { value: null } as unknown as { + value: ToastServiceMethods +} + +export const useToast = () => { + const toast = usePrimeToast() + const confirm = usePrimeConfirm() + + globalToast.value = toast + + const wrapperToastError = (callback: T): T => { + const showToast = (error: Error) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: error.message, + life: 15000, + }) + } + + const isAsync = callback.constructor.name === 'AsyncFunction' + + let wrapperExec: any + + if (isAsync) { + wrapperExec = (...args: any[]) => callback(...args).catch(showToast) + } else { + wrapperExec = (...args: any[]) => { + try { + return callback(...args) + } catch (error) { + showToast(error) + } + } + } + + return wrapperExec + } + + return { toast, wrapperToastError, confirm } +} diff --git a/src/hooks/utils.ts b/src/hooks/utils.ts new file mode 100644 index 0000000..28a8d58 --- /dev/null +++ b/src/hooks/utils.ts @@ -0,0 +1,11 @@ +import { ref } from 'vue' + +export const useBoolean = (defaultValue?: boolean) => { + const target = ref(defaultValue ?? false) + + const toggle = (value?: any) => { + target.value = typeof value === 'boolean' ? value : !target.value + } + + return [target, toggle] as const +} diff --git a/src/i18n.ts b/src/i18n.ts new file mode 100644 index 0000000..f30a2a9 --- /dev/null +++ b/src/i18n.ts @@ -0,0 +1,88 @@ +import { createI18n } from 'vue-i18n' + +const messages = { + en: { + model: 'Model', + modelManager: 'Model Manager', + openModelManager: 'Open Model Manager', + searchModels: 'Search models', + modelCopied: 'Model Copied', + download: 'Download', + downloadList: 'Download List', + downloadTask: 'Download Task', + createDownloadTask: 'Create Download Task', + parseModelUrl: 'Parse Model URL', + pleaseInputModelUrl: 'Input a URL from civitai.com or huggingface.co', + cancel: 'Cancel', + save: 'Save', + delete: 'Delete', + deleteAsk: 'Confirm delete this {0}?', + modelType: 'Model Type', + default: 'Default', + network: 'Network', + local: 'Local', + none: 'None', + uploadFile: 'Upload File', + tapToChange: 'Tap description to change content', + sort: { + name: 'Name', + size: 'Largest', + created: 'Latest created', + modified: 'Latest modified', + }, + info: { + type: 'Model Type', + fullname: 'File Name', + sizeBytes: 'File Size', + createdAt: 'Created At', + updatedAt: 'Updated At', + }, + }, + zh: { + model: '模型', + modelManager: '模型管理器', + openModelManager: '打开模型管理器', + searchModels: '搜索模型', + modelCopied: '模型节点已拷贝', + download: '下载', + downloadList: '下载列表', + downloadTask: '下载任务', + createDownloadTask: '创建下载任务', + parseModelUrl: '解析模型URL', + pleaseInputModelUrl: '输入 civitai.com 或 huggingface.co 的 URL', + cancel: '取消', + save: '保存', + delete: '删除', + deleteAsk: '确定要删除此{0}?', + modelType: '模型类型', + default: '默认', + network: '网络', + local: '本地', + none: '无', + uploadFile: '上传文件', + tapToChange: '点击描述可更改内容', + sort: { + name: '名称', + size: '最大', + created: '最新创建', + modified: '最新修改', + }, + info: { + type: '类型', + fullname: '文件名', + sizeBytes: '文件大小', + createdAt: '创建时间', + updatedAt: '更新时间', + }, + }, +} + +export const i18n = createI18n({ + legacy: false, + locale: + localStorage.getItem('Comfy.Settings.Comfy.Locale') || + navigator.language.split('-')[0] || + 'en', + fallbackLocale: 'en', + messages, +}) diff --git a/src/main.ts b/src/main.ts new file mode 100644 index 0000000..c2eb44a --- /dev/null +++ b/src/main.ts @@ -0,0 +1,55 @@ +import { definePreset } from '@primevue/themes' +import Aura from '@primevue/themes/aura' +import { resizeDirective } from 'hooks/resize' +import PrimeVue from 'primevue/config' +import ConfirmationService from 'primevue/confirmationservice' +import ToastService from 'primevue/toastservice' +import Tooltip from 'primevue/tooltip' +import { app } from 'scripts/comfyAPI' +import { createApp } from 'vue' +import App from './App.vue' +import { i18n } from './i18n' +import './style.css' + +const ComfyUIPreset = definePreset(Aura, { + semantic: { + primary: Aura['primitive'].blue, + }, +}) + +function createVueApp(rootContainer: string | HTMLElement) { + const app = createApp(App) + app.directive('tooltip', Tooltip) + app.directive('resize', resizeDirective) + app + .use(PrimeVue, { + theme: { + preset: ComfyUIPreset, + options: { + prefix: 'p', + cssLayer: { + name: 'primevue', + order: 'tailwind-base, primevue, tailwind-utilities', + }, + // This is a workaround for the issue with the dark mode selector + // https://github.com/primefaces/primevue/issues/5515 + darkModeSelector: '.dark-theme, :root:has(.dark-theme)', + }, + }, + }) + .use(ToastService) + .use(ConfirmationService) + .use(i18n) + .mount(rootContainer) +} + +app.registerExtension({ + name: 'Comfy.ModelManager', + setup() { + const container = document.createElement('div') + container.id = 'comfyui-model-manager' + document.body.appendChild(container) + + createVueApp(container) + }, +}) diff --git a/src/scripts/comfyAPI.ts b/src/scripts/comfyAPI.ts new file mode 100644 index 0000000..61394ff --- /dev/null +++ b/src/scripts/comfyAPI.ts @@ -0,0 +1,7 @@ +export const app = window.comfyAPI.app.app +export const api = window.comfyAPI.api.api + +export const $el = window.comfyAPI.ui.$el + +export const ComfyApp = window.comfyAPI.app.ComfyApp +export const ComfyButton = window.comfyAPI.button.ComfyButton diff --git a/src/style.css b/src/style.css new file mode 100644 index 0000000..6d7905a --- /dev/null +++ b/src/style.css @@ -0,0 +1,157 @@ +@layer primevue, tailwind-utilities; + +@layer tailwind-utilities { + @tailwind components; + @tailwind utilities; + + :root { + --tw-border-spacing-x: 0; + --tw-border-spacing-y: 0; + --tw-translate-x: 0; + --tw-translate-y: 0; + --tw-rotate: 0; + --tw-skew-x: 0; + --tw-skew-y: 0; + --tw-scale-x: 1; + --tw-scale-y: 1; + --tw-pan-x: ; + --tw-pan-y: ; + --tw-pinch-zoom: ; + --tw-scroll-snap-strictness: proximity; + --tw-gradient-from-position: ; + --tw-gradient-via-position: ; + --tw-gradient-to-position: ; + --tw-ordinal: ; + --tw-slashed-zero: ; + --tw-numeric-figure: ; + --tw-numeric-spacing: ; + --tw-numeric-fraction: ; + --tw-ring-inset: ; + --tw-ring-offset-width: 0px; + --tw-ring-offset-color: #fff; + --tw-ring-color: rgb(59 130 246 / 0.5); + --tw-ring-offset-shadow: 0 0 #0000; + --tw-ring-shadow: 0 0 #0000; + --tw-shadow: 0 0 #0000; + --tw-shadow-colored: 0 0 #0000; + --tw-blur: ; + --tw-brightness: ; + --tw-contrast: ; + --tw-grayscale: ; + --tw-hue-rotate: ; + --tw-invert: ; + --tw-saturate: ; + --tw-sepia: ; + --tw-drop-shadow: ; + --tw-backdrop-blur: ; + --tw-backdrop-brightness: ; + --tw-backdrop-contrast: ; + --tw-backdrop-grayscale: ; + --tw-backdrop-hue-rotate: ; + --tw-backdrop-invert: ; + --tw-backdrop-opacity: ; + --tw-backdrop-saturate: ; + --tw-backdrop-sepia: ; + --tw-contain-size: ; + --tw-contain-layout: ; + --tw-contain-paint: ; + --tw-contain-style: ; + } + + *.border, + *.border-x, + *.border-y, + *.border-l, + *.border-t, + *.border-r, + *.border-b { + border-style: solid; + } + + table, + th, + tr, + td { + border-width: 0px; + } +} + +.comfy-modal { + z-index: 3000; +} + +.markdown-it { + font-family: theme('fontFamily.sans'); + line-height: theme('lineHeight.relaxed'); + word-break: break-word; + margin: 0; + + h1 { + font-size: theme('fontSize.2xl'); + font-weight: theme('fontWeight.bold'); + border-bottom: 1px solid #ddd; + margin-top: theme('margin.4'); + margin-bottom: theme('margin.4'); + padding-bottom: theme('padding[2.5]'); + } + + h2 { + font-size: theme('fontSize.xl'); + font-weight: theme('fontWeight.bold'); + } + + h3 { + font-size: theme('fontSize.lg'); + } + + a { + color: #1e8bc3; + text-decoration: none; + word-break: break-all; + } + + a:hover { + text-decoration: underline; + } + + p { + margin: 1em 0; + } + + p img { + width: 100%; + height: 100%; + object-fit: cover; + } + + ul, + ol { + margin: 1em 0; + padding-left: 2em; + } + + li { + margin: 0.5em 0; + } + + blockquote { + border-left: 5px solid #ddd; + padding: 10px 20px; + margin: 1.5em 0; + background: #f9f9f9; + } + + code, + pre { + background: #f9f9f9; + padding: 3px 5px; + border: 1px solid #ddd; + border-radius: 3px; + font-family: 'Courier New', Courier, monospace; + } + + pre { + padding: 10px; + overflow-x: auto; + } +} diff --git a/src/types/global.d.ts b/src/types/global.d.ts new file mode 100644 index 0000000..9a4e612 --- /dev/null +++ b/src/types/global.d.ts @@ -0,0 +1,272 @@ +declare namespace ComfyAPI { + namespace api { + class ComfyApi { + socket: WebSocket + fetchApi: (route: string, options?: RequestInit) => Promise + addEventListener: ( + type: string, + callback: (event: CustomEvent) => void, + options?: AddEventListenerOptions, + ) => void + } + + const api: ComfyApi + } + + namespace app { + interface ComfyExtension { + /** + * The name of the extension + */ + name: string + /** + * Allows any initialisation, e.g. loading resources. Called after the canvas is created but before nodes are added + * @param app The ComfyUI app instance + */ + init?(app: ComfyApp): Promise | void + /** + * Allows any additional setup, called after the application is fully set up and running + * @param app The ComfyUI app instance + */ + setup?(app: ComfyApp): Promise | void + } + + interface BaseSidebarTabExtension { + id: string + title: string + icon?: string + iconBadge?: string | (() => string | null) + order?: number + tooltip?: string + } + + interface VueSidebarTabExtension extends BaseSidebarTabExtension { + type: 'vue' + component: import('vue').Component + } + + interface CustomSidebarTabExtension extends BaseSidebarTabExtension { + type: 'custom' + render: (container: HTMLElement) => void + destroy?: () => void + } + + type SidebarTabExtension = + | VueSidebarTabExtension + | CustomSidebarTabExtension + + interface ExtensionManager { + // Sidebar tabs + registerSidebarTab(tab: SidebarTabExtension): void + unregisterSidebarTab(id: string): void + getSidebarTabs(): SidebarTabExtension[] + + // Toast + toast: ToastManager + } + + class ComfyApp { + ui?: ui.ComfyUI + menu?: index.ComfyAppMenu + graph: lightGraph.LGraph + canvas: lightGraph.LGraphCanvas + extensionManager: ExtensionManager + registerExtension: (extension: ComfyExtension) => void + addNodeOnGraph: ( + nodeDef: lightGraph.ComfyNodeDef, + options?: Record, + ) => lightGraph.LGraphNode + getCanvasCenter: () => lightGraph.Vector2 + clientPosToCanvasPos: (pos: lightGraph.Vector2) => lightGraph.Vector2 + handleFile: (file: File) => void + } + + const app: ComfyApp + } + + namespace ui { + type Props = { + parent?: HTMLElement + $?: (el: HTMLElement) => void + dataset?: DOMStringMap + style?: Partial + for?: string + textContent?: string + [key: string]: any + } + + type Children = Element[] | Element | string | string[] + + type ElementType = K extends keyof HTMLElementTagNameMap + ? HTMLElementTagNameMap[K] + : HTMLElement + + const $el: ( + tag: TTag, + propsOrChildren?: Children | Props, + children?: Children, + ) => ElementType + + class ComfyUI { + app: app.ComfyApp + settings: ComfySettingsDialog + menuHamburger?: HTMLDivElement + menuContainer?: HTMLDivElement + } + + type SettingInputType = + | 'boolean' + | 'number' + | 'slider' + | 'combo' + | 'text' + | 'hidden' + + type SettingCustomRenderer = ( + name: string, + setter: (v: any) => void, + value: any, + attrs: any, + ) => HTMLElement + + interface SettingOption { + text: string + value?: string + } + + interface SettingParams { + id: string + name: string + type: SettingInputType | SettingCustomRenderer + defaultValue: any + onChange?: (newValue: any, oldValue?: any) => void + attrs?: any + tooltip?: string + options?: + | Array + | ((value: any) => SettingOption[]) + // By default category is id.split('.'). However, changing id to assign + // new category has poor backward compatibility. Use this field to overwrite + // default category from id. + // Note: Like id, category value need to be unique. + category?: string[] + experimental?: boolean + deprecated?: boolean + } + + class ComfySettingsDialog { + addSetting: (params: SettingParams) => { value: any } + } + } + + namespace index { + class ComfyAppMenu { + app: app.ComfyApp + logo: HTMLElement + actionsGroup: button.ComfyButtonGroup + settingsGroup: button.ComfyButtonGroup + viewGroup: button.ComfyButtonGroup + mobileMenuButton: ComfyButton + element: HTMLElement + } + } + + namespace button { + type ComfyButtonProps = { + icon?: string + overIcon?: string + iconSize?: number + content?: string | HTMLElement + tooltip?: string + enabled?: boolean + action?: (e: Event, btn: ComfyButton) => void + classList?: ClassList + visibilitySetting?: { id: keyof Settings; showValue: boolean } + app?: app.ComfyApp + } + + class ComfyButton { + constructor(props: ComfyButtonProps): ComfyButton + } + + class ComfyButtonGroup { + insert(button: ComfyButton, index: number): void + append(button: ComfyButton): void + remove(indexOrButton: ComfyButton | number): void + update(): void + constructor(...buttons: (HTMLElement | ComfyButton)[]): ComfyButtonGroup + } + } +} + +declare namespace lightGraph { + class LGraphNode implements ComfyNodeDef { + widgets: any[] + pos: Vector2 + } + + class LGraphGroup {} + + class LGraph { + /** + * Adds a new node instance to this graph + * @param node the instance of the node + */ + add(node: LGraphNode | LGraphGroup, skip_compute_order?: boolean): void + /** + * Returns the top-most node in this position of the canvas + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @param nodes_list a list with all the nodes to search from, by default is all the nodes in the graph + * @return the node at this position or null + */ + getNodeOnPos( + x: number, + y: number, + node_list?: LGraphNode[], + margin?: number, + ): T | null + } + + class LGraphCanvas { + selected_nodes: Record + canvas_mouse: Vector2 + selectNode: (node: LGraphNode) => void + copyToClipboard: (nodes: LGraphNode[]) => void + } + + const LiteGraph: { + createNode: ( + type: string, + title: string | null, + options: object, + ) => LGraphNode + } + + type ComfyNodeDef = { + input?: { + required?: Record + optional?: Record + hidden?: Record + } + output?: (string | any[])[] + output_is_list?: boolean[] + output_name?: string[] + output_tooltips?: string[] + name?: string + display_name?: string + description?: string + category?: string + output_node?: boolean + python_module?: string + deprecated?: boolean + experimental?: boolean + } + + type Vector2 = [number, number] +} + +interface Window { + comfyAPI: typeof ComfyAPI + LiteGraph: typeof lightGraph.LiteGraph +} diff --git a/src/types/shims.d.ts b/src/types/shims.d.ts new file mode 100644 index 0000000..38dbf75 --- /dev/null +++ b/src/types/shims.d.ts @@ -0,0 +1,11 @@ +export {} + +declare module 'vue' { + interface ComponentCustomProperties { + vResize: (typeof import('hooks/resize'))['resizeDirective'] + } +} + +declare module 'hooks/store' { + interface StoreProvider {} +} diff --git a/src/types/typings.d.ts b/src/types/typings.d.ts new file mode 100644 index 0000000..b15f900 --- /dev/null +++ b/src/types/typings.d.ts @@ -0,0 +1,69 @@ +interface BaseModel { + id: number | string + fullname: string + basename: string + extension: string + sizeBytes: number + type: string + pathIndex: number + preview: string | string[] + description: string + metadata: Record +} + +interface Model extends BaseModel { + createdAt: number + updatedAt: number +} + +interface VersionModel extends BaseModel { + shortname: string + downloadPlatform: string + downloadUrl: string + hashes?: Record +} + +type PassThrough = T | object | undefined + +interface SelectOptions { + label: string + value: any + icon?: string + command: () => void +} + +interface SelectFile extends File { + objectURL: string +} + +interface SelectEvent { + files: SelectFile[] + originalEvent: Event +} + +interface DownloadTaskOptions { + taskId: string + type: string + fullname: string + preview: string + status: 'pause' | 'waiting' | 'doing' + progress: number + downloadedSize: number + totalSize: number + bps: number + error?: string +} + +interface DownloadTask + extends Omit< + DownloadTaskOptions, + 'downloadedSize' | 'totalSize' | 'bps' | 'error' + > { + downloadProgress: string + downloadSpeed: string + pauseTask: () => void + resumeTask: () => void + deleteTask: () => void +} + +type CustomEventListener = (event: CustomEvent) => void diff --git a/src/utils/common.ts b/src/utils/common.ts new file mode 100644 index 0000000..5d6a259 --- /dev/null +++ b/src/utils/common.ts @@ -0,0 +1,39 @@ +import dayjs from 'dayjs' + +export const bytesToSize = ( + bytes: number | string | undefined | null, + decimals = 2, +) => { + if (typeof bytes === 'undefined' || bytes === null) { + bytes = 0 + } + if (typeof bytes === 'string') { + bytes = Number(bytes) + } + if (Number.isNaN(bytes)) { + return 'Unknown' + } + if (bytes === 0) { + return '0 Bytes' + } + const k = 1024 + const dm = decimals < 0 ? 0 : decimals + const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'] + const i = Math.floor(Math.log(bytes) / Math.log(k)) + return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i] +} + +export const formatDate = (date: number | string | Date) => { + return dayjs(date).format('YYYY-MM-DD HH:mm:ss') +} + +export const previewUrlToFile = async (url: string) => { + return fetch(url) + .then((res) => res.blob()) + .then((blob) => { + const type = blob.type + const extension = type.split('/')[1] + const file = new File([blob], `preview.${extension}`, { type }) + return file + }) +} diff --git a/src/utils/legacy.ts b/src/utils/legacy.ts new file mode 100644 index 0000000..bc74917 --- /dev/null +++ b/src/utils/legacy.ts @@ -0,0 +1,620 @@ +// @ts-nocheck +import { app } from 'scripts/comfyAPI' + +const LiteGraph = window.LiteGraph + +const modelNodeType = { + checkpoints: 'CheckpointLoaderSimple', + clip: 'CLIPLoader', + clip_vision: 'CLIPVisionLoader', + controlnet: 'ControlNetLoader', + diffusers: 'DiffusersLoader', + embeddings: 'Embedding', + gligen: 'GLIGENLoader', + hypernetworks: 'HypernetworkLoader', + photomaker: 'PhotoMakerLoader', + loras: 'LoraLoader', + style_models: 'StyleModelLoader', + unet: 'UNETLoader', + upscale_models: 'UpscaleModelLoader', + vae: 'VAELoader', + vae_approx: undefined, +} + +export class ModelGrid { + /** + * @param {string} nodeType + * @returns {int} + */ + static modelWidgetIndex(nodeType) { + return nodeType === undefined ? -1 : 0 + } + + /** + * @param {string} text + * @param {string} file + * @param {boolean} removeExtension + * @returns {string} + */ + static insertEmbeddingIntoText(text, file, removeExtension) { + let name = file + if (removeExtension) { + name = SearchPath.splitExtension(name)[0] + } + const sep = text.length === 0 || text.slice(-1).match(/\s/) ? '' : ' ' + return text + sep + '(embedding:' + name + ':1.0)' + } + + /** + * @param {Array} list + * @param {string} searchString + * @returns {Array} + */ + static #filter(list, searchString) { + /** @type {string[]} */ + const keywords = searchString + //.replace("*", " ") // TODO: this is wrong for wildcards + .split(/(-?".*?"|[^\s"]+)+/g) + .map((item) => + item + .trim() + .replace(/(?:")+/g, '') + .toLowerCase(), + ) + .filter(Boolean) + + const regexSHA256 = /^[a-f0-9]{64}$/gi + const fields = ['name', 'path'] + return list.filter((element) => { + const text = fields + .reduce((memo, field) => memo + ' ' + element[field], '') + .toLowerCase() + return keywords.reduce((memo, target) => { + const excludeTarget = target[0] === '-' + if (excludeTarget && target.length === 1) { + return memo + } + const filteredTarget = excludeTarget ? target.slice(1) : target + if ( + element['SHA256'] !== undefined && + regexSHA256.test(filteredTarget) + ) { + return ( + memo && excludeTarget !== (filteredTarget === element['SHA256']) + ) + } else { + return memo && excludeTarget !== text.includes(filteredTarget) + } + }, true) + }) + } + + /** + * In-place sort. Returns an array alias. + * @param {Array} list + * @param {string} sortBy + * @param {bool} [reverse=false] + * @returns {Array} + */ + static #sort(list, sortBy, reverse = false) { + let compareFn = null + switch (sortBy) { + case MODEL_SORT_DATE_NAME: + compareFn = (a, b) => { + return a[MODEL_SORT_DATE_NAME].localeCompare(b[MODEL_SORT_DATE_NAME]) + } + break + case MODEL_SORT_DATE_MODIFIED: + compareFn = (a, b) => { + return b[MODEL_SORT_DATE_MODIFIED] - a[MODEL_SORT_DATE_MODIFIED] + } + break + case MODEL_SORT_DATE_CREATED: + compareFn = (a, b) => { + return b[MODEL_SORT_DATE_CREATED] - a[MODEL_SORT_DATE_CREATED] + } + break + case MODEL_SORT_SIZE_BYTES: + compareFn = (a, b) => { + return b[MODEL_SORT_SIZE_BYTES] - a[MODEL_SORT_SIZE_BYTES] + } + break + default: + console.warn("Invalid filter sort value: '" + sortBy + "'") + return list + } + const sorted = list.sort(compareFn) + return reverse ? sorted.reverse() : sorted + } + + /** + * @param {Event} event + * @param {string} modelType + * @param {string} path + * @param {boolean} removeEmbeddingExtension + * @param {int} addOffset + */ + static #addModel( + event, + modelType, + path, + removeEmbeddingExtension, + addOffset, + ) { + let success = false + if (modelType !== 'embeddings') { + const nodeType = modelNodeType[modelType] + const widgetIndex = ModelGrid.modelWidgetIndex(nodeType) + let node = LiteGraph.createNode(nodeType, null, []) + if (widgetIndex !== -1 && node) { + node.widgets[widgetIndex].value = path + const selectedNodes = app.canvas.selected_nodes + let isSelectedNode = false + for (var i in selectedNodes) { + const selectedNode = selectedNodes[i] + node.pos[0] = selectedNode.pos[0] + addOffset + node.pos[1] = selectedNode.pos[1] + addOffset + isSelectedNode = true + break + } + if (!isSelectedNode) { + const graphMouse = app.canvas.graph_mouse + node.pos[0] = graphMouse[0] + node.pos[1] = graphMouse[1] + } + app.graph.add(node, { doProcessChange: true }) + app.canvas.selectNode(node) + success = true + } + event.stopPropagation() + } else if (modelType === 'embeddings') { + const [embeddingDirectory, embeddingFile] = SearchPath.split(path) + const selectedNodes = app.canvas.selected_nodes + for (var i in selectedNodes) { + const selectedNode = selectedNodes[i] + const nodeType = modelNodeType[modelType] + const widgetIndex = ModelGrid.modelWidgetIndex(nodeType) + const target = selectedNode?.widgets[widgetIndex]?.element + if (target && target.type === 'textarea') { + target.value = ModelGrid.insertEmbeddingIntoText( + target.value, + embeddingFile, + removeEmbeddingExtension, + ) + success = true + } + } + if (!success) { + console.warn('Try selecting a node before adding the embedding.') + } + event.stopPropagation() + } + comfyButtonAlert(event.target, success, 'mdi-check-bold', 'mdi-close-thick') + } + + static #getWidgetComboIndices(node, value) { + const widgetIndices = [] + node?.widgets?.forEach((widget, index) => { + if (widget.type === 'combo' && widget.options.values?.includes(value)) { + widgetIndices.push(index) + } + }) + return widgetIndices + } + + /** + * @param {DragEvent} event + * @param {string} modelType + * @param {string} path + * @param {boolean} removeEmbeddingExtension + * @param {boolean} strictlyOnWidget + */ + static dragAddModel( + event, + modelType, + path, + removeEmbeddingExtension, + strictlyOnWidget, + ) { + const target = document.elementFromPoint(event.clientX, event.clientY) + if (modelType !== 'embeddings' && target.id === 'graph-canvas') { + const pos = app.canvas.convertEventToCanvasOffset(event) + + const node = app.graph.getNodeOnPos( + pos[0], + pos[1], + app.canvas.visible_nodes, + ) + + let widgetIndex = -1 + if (widgetIndex === -1) { + const widgetIndices = this.#getWidgetComboIndices(node, path) + if (widgetIndices.length === 0) { + widgetIndex = -1 + } else if (widgetIndices.length === 1) { + widgetIndex = widgetIndices[0] + if (strictlyOnWidget) { + const draggedWidget = app.canvas.processNodeWidgets( + node, + pos, + event, + ) + const widget = node.widgets[widgetIndex] + if (draggedWidget != widget) { + // != check NOT same object + widgetIndex = -1 + } + } + } else { + // ambiguous widget (strictlyOnWidget always true) + const draggedWidget = app.canvas.processNodeWidgets(node, pos, event) + widgetIndex = widgetIndices.findIndex((index) => { + return draggedWidget == node.widgets[index] // == check same object + }) + } + } + + if (widgetIndex !== -1) { + node.widgets[widgetIndex].value = path + app.canvas.selectNode(node) + } else { + const expectedNodeType = modelNodeType[modelType] + const newNode = LiteGraph.createNode(expectedNodeType, null, []) + let newWidgetIndex = ModelGrid.modelWidgetIndex(expectedNodeType) + if (newWidgetIndex === -1) { + newWidgetIndex = this.#getWidgetComboIndices(newNode, path)[0] ?? -1 + } + if ( + newNode !== undefined && + newNode !== null && + newWidgetIndex !== -1 + ) { + newNode.pos[0] = pos[0] + newNode.pos[1] = pos[1] + newNode.widgets[newWidgetIndex].value = path + app.graph.add(newNode, { doProcessChange: true }) + app.canvas.selectNode(newNode) + } + } + event.stopPropagation() + } else if (modelType === 'embeddings' && target.type === 'textarea') { + const pos = app.canvas.convertEventToCanvasOffset(event) + const nodeAtPos = app.graph.getNodeOnPos( + pos[0], + pos[1], + app.canvas.visible_nodes, + ) + if (nodeAtPos) { + app.canvas.selectNode(nodeAtPos) + const [embeddingDirectory, embeddingFile] = SearchPath.split(path) + target.value = ModelGrid.insertEmbeddingIntoText( + target.value, + embeddingFile, + removeEmbeddingExtension, + ) + event.stopPropagation() + } + } + } + + /** + * @param {Event} event + * @param {string} modelType + * @param {string} path + * @param {boolean} removeEmbeddingExtension + */ + static #copyModelToClipboard( + event, + modelType, + path, + removeEmbeddingExtension, + ) { + const nodeType = modelNodeType[modelType] + let success = false + if (nodeType === 'Embedding') { + if (navigator.clipboard) { + const [embeddingDirectory, embeddingFile] = SearchPath.split(path) + const embeddingText = ModelGrid.insertEmbeddingIntoText( + '', + embeddingFile, + removeEmbeddingExtension, + ) + navigator.clipboard.writeText(embeddingText) + success = true + } else { + console.warn( + 'Cannot copy the embedding to the system clipboard; Try dragging it instead.', + ) + } + } else if (nodeType) { + const node = LiteGraph.createNode(nodeType, null, []) + const widgetIndex = ModelGrid.modelWidgetIndex(nodeType) + if (widgetIndex !== -1) { + node.widgets[widgetIndex].value = path + app.canvas.copyToClipboard([node]) + success = true + } + } else { + console.warn(`Unable to copy unknown model type '${modelType}.`) + } + comfyButtonAlert(event.target, success, 'mdi-check-bold', 'mdi-close-thick') + } + + /** + * @param {Array} models + * @param {string} modelType + * @param {Object.} settingsElements + * @param {String} searchSeparator + * @param {String} systemSeparator + * @param {(searchPath: string) => Promise} showModelInfo + * @returns {HTMLElement[]} + */ + static #generateInnerHtml( + models, + modelType, + settingsElements, + searchSeparator, + systemSeparator, + showModelInfo, + ) { + // TODO: separate text and model logic; getting too messy + // TODO: fallback on button failure to copy text? + const canShowButtons = modelNodeType[modelType] !== undefined + const showAddButton = + canShowButtons && settingsElements['model-show-add-button'].checked + const showCopyButton = + canShowButtons && settingsElements['model-show-copy-button'].checked + const showLoadWorkflowButton = + canShowButtons && + settingsElements['model-show-load-workflow-button'].checked + const strictDragToAdd = + settingsElements['model-add-drag-strict-on-field'].checked + const addOffset = parseInt(settingsElements['model-add-offset'].value) + const showModelExtension = + settingsElements['model-show-label-extensions'].checked + const modelInfoButtonOnLeft = + !settingsElements['model-info-button-on-left'].checked + const removeEmbeddingExtension = + !settingsElements['model-add-embedding-extension'].checked + const previewThumbnailFormat = + settingsElements['model-preview-thumbnail-type'].value + const previewThumbnailWidth = Math.round( + settingsElements['model-preview-thumbnail-width'].value / 0.75, + ) + const previewThumbnailHeight = Math.round( + settingsElements['model-preview-thumbnail-height'].value / 0.75, + ) + const buttonsOnlyOnHover = + settingsElements['model-buttons-only-on-hover'].checked + if (models.length > 0) { + const $overlay = IS_FIREFOX + ? (modelType, path, removeEmbeddingExtension, strictDragToAdd) => { + return $el('div.model-preview-overlay', { + ondragstart: (e) => { + const data = { + modelType: modelType, + path: path, + removeEmbeddingExtension: removeEmbeddingExtension, + strictDragToAdd: strictDragToAdd, + } + e.dataTransfer.setData('manager-model', JSON.stringify(data)) + e.dataTransfer.setData('text/plain', '') + }, + draggable: true, + }) + } + : (modelType, path, removeEmbeddingExtension, strictDragToAdd) => { + return $el('div.model-preview-overlay', { + ondragend: (e) => + ModelGrid.dragAddModel( + e, + modelType, + path, + removeEmbeddingExtension, + strictDragToAdd, + ), + draggable: true, + }) + } + const forHiddingButtonsClass = buttonsOnlyOnHover + ? 'model-buttons-hidden' + : 'model-buttons-visible' + + return models.map((item) => { + const previewInfo = item.preview + const previewThumbnail = $el('img.model-preview', { + loading: + 'lazy' /* `loading` BEFORE `src`; Known bug in Firefox 124.0.2 and Safari for iOS 17.4.1 (https://stackoverflow.com/a/76252772) */, + src: imageUri( + previewInfo?.path, + previewInfo?.dateModified, + previewThumbnailWidth, + previewThumbnailHeight, + previewThumbnailFormat, + ), + draggable: false, + }) + const searchPath = item.path + const path = SearchPath.systemPath( + searchPath, + searchSeparator, + systemSeparator, + ) + let actionButtons = [] + if (showCopyButton) { + actionButtons.push( + new ComfyButton({ + icon: 'content-copy', + tooltip: 'Copy model to clipboard', + classList: 'comfyui-button icon-button model-button', + action: (e) => + ModelGrid.#copyModelToClipboard( + e, + modelType, + path, + removeEmbeddingExtension, + ), + }).element, + ) + } + if ( + showAddButton && + !(modelType === 'embeddings' && !navigator.clipboard) + ) { + actionButtons.push( + new ComfyButton({ + icon: 'plus-box-outline', + tooltip: 'Add model to node grid', + classList: 'comfyui-button icon-button model-button', + action: (e) => + ModelGrid.#addModel( + e, + modelType, + path, + removeEmbeddingExtension, + addOffset, + ), + }).element, + ) + } + if (showLoadWorkflowButton) { + actionButtons.push( + new ComfyButton({ + icon: 'arrow-bottom-left-bold-box-outline', + tooltip: 'Load preview workflow', + classList: 'comfyui-button icon-button model-button', + action: async (e) => { + const urlString = previewThumbnail.src + const url = new URL(urlString) + const urlSearchParams = url.searchParams + const uri = urlSearchParams.get('uri') + const v = urlSearchParams.get('v') + const urlFull = + urlString.substring(0, urlString.indexOf('?')) + + '?uri=' + + uri + + '&v=' + + v + await loadWorkflow(urlFull) + }, + }).element, + ) + } + const infoButtons = [ + new ComfyButton({ + icon: 'information-outline', + tooltip: 'View model information', + classList: 'comfyui-button icon-button model-button', + action: async () => { + await showModelInfo(searchPath) + }, + }).element, + ] + return $el('div.item', {}, [ + previewThumbnail, + $overlay(modelType, path, removeEmbeddingExtension, strictDragToAdd), + $el( + 'div.model-preview-top-right.' + forHiddingButtonsClass, + { + draggable: false, + }, + modelInfoButtonOnLeft ? infoButtons : actionButtons, + ), + $el( + 'div.model-preview-top-left.' + forHiddingButtonsClass, + { + draggable: false, + }, + modelInfoButtonOnLeft ? actionButtons : infoButtons, + ), + $el( + 'div.model-label', + { + draggable: false, + }, + [ + $el('p', [ + showModelExtension + ? item.name + : SearchPath.splitExtension(item.name)[0], + ]), + ], + ), + ]) + }) + } else { + return [$el('h2', ['No Models'])] + } + } + + /** + * @param {HTMLDivElement} modelGrid + * @param {ModelData} modelData + * @param {HTMLSelectElement} modelSelect + * @param {Object.<{value: string}>} previousModelType + * @param {Object} settings + * @param {string} sortBy + * @param {boolean} reverseSort + * @param {Array} previousModelFilters + * @param {HTMLInputElement} modelFilter + * @param {(searchPath: string) => Promise} showModelInfo + */ + static update( + modelGrid, + modelData, + modelSelect, + previousModelType, + settings, + sortBy, + reverseSort, + previousModelFilters, + modelFilter, + showModelInfo, + ) { + const models = modelData.models + let modelType = modelSelect.value + if (models[modelType] === undefined) { + modelType = settings['model-default-browser-model-type'].value + } + if (models[modelType] === undefined) { + modelType = 'checkpoints' // panic fallback + } + + if (modelType !== previousModelType.value) { + if (settings['model-persistent-search'].checked) { + previousModelFilters.splice(0, previousModelFilters.length) // TODO: make sure this actually worked! + } else { + // cache previous filter text + previousModelFilters[previousModelType.value] = modelFilter.value + // read cached filter text + modelFilter.value = previousModelFilters[modelType] ?? '' + } + previousModelType.value = modelType + } + + let modelTypeOptions = [] + for (const [key, value] of Object.entries(models)) { + const el = $el('option', [key]) + modelTypeOptions.push(el) + } + modelSelect.innerHTML = '' + modelTypeOptions.forEach((option) => modelSelect.add(option)) + modelSelect.value = modelType + + const searchAppend = settings['model-search-always-append'].value + const searchText = modelFilter.value + ' ' + searchAppend + const modelList = ModelGrid.#filter(models[modelType], searchText) + ModelGrid.#sort(modelList, sortBy, reverseSort) + + modelGrid.innerHTML = '' + const modelGridModels = ModelGrid.#generateInnerHtml( + modelList, + modelType, + settings, + modelData.searchSeparator, + modelData.systemSeparator, + showModelInfo, + ) + modelGrid.append.apply(modelGrid, modelGridModels) + } +} diff --git a/src/utils/model.ts b/src/utils/model.ts new file mode 100644 index 0000000..af1b364 --- /dev/null +++ b/src/utils/model.ts @@ -0,0 +1,45 @@ +const loader = { + checkpoints: 'CheckpointLoaderSimple', + clip: 'CLIPLoader', + clip_vision: 'CLIPVisionLoader', + controlnet: 'ControlNetLoader', + diffusers: 'DiffusersLoader', + diffusion_models: 'DiffusersLoader', + embeddings: 'Embedding', + gligen: 'GLIGENLoader', + hypernetworks: 'HypernetworkLoader', + photomaker: 'PhotoMakerLoader', + loras: 'LoraLoader', + style_models: 'StyleModelLoader', + unet: 'UNETLoader', + upscale_models: 'UpscaleModelLoader', + vae: 'VAELoader', + vae_approx: undefined, +} + +const display = { + all: 'ALL', + checkpoints: 'Checkpoint', + clip: 'Clip', + clip_vision: 'Clip Vision', + controlnet: 'Controlnet', + diffusers: 'Diffusers', + diffusion_models: 'Diffusers', + embeddings: 'embedding', + gligen: 'Gligen', + hypernetworks: 'Hypernetwork', + photomaker: 'Photomaker', + loras: 'LoRA', + style_models: 'Style Model', + unet: 'Unet', + upscale_models: 'Upscale Model', + vae: 'VAE', + vae_approx: 'VAE approx', +} + +export const resolveModelType = (type: string) => { + return { + display: display[type], + loader: loader[type], + } +} diff --git a/src/vite-env.d.ts b/src/vite-env.d.ts new file mode 100644 index 0000000..11f02fe --- /dev/null +++ b/src/vite-env.d.ts @@ -0,0 +1 @@ +/// diff --git a/tailwind.config.js b/tailwind.config.js index 5c8bc02..b562990 100644 --- a/tailwind.config.js +++ b/tailwind.config.js @@ -1,8 +1,213 @@ +import container from '@tailwindcss/container-queries' +import plugin from 'tailwindcss/plugin' + /** @type {import('tailwindcss').Config} */ export default { - content: [], - theme: { - extend: {}, + content: ['index.html', './src/**/*.vue'], + + darkMode: ['selector', '.dark-theme'], + + plugins: [ + container, + plugin(({ addUtilities }) => { + addUtilities({ + '.scrollbar-none': { + 'scrollbar-width': 'none', + }, + '.preview-aspect': { + 'aspect-ratio': '7/9', + img: { + width: '100%', + height: '100%', + objectFit: 'cover', + display: 'block', + }, + }, + }) + }), + ], + + corePlugins: { + preflight: false, // This disables Tailwind's base styles + }, + + theme: { + fontSize: { + xs: '0.75rem', + sm: '0.875rem', + base: '1rem', + lg: '1.125rem', + xl: '1.25rem', + '2xl': '1.5rem', + '3xl': '1.875rem', + '4xl': '2.25rem', + '5xl': '3rem', + '6xl': '4rem', + }, + + screens: { + sm: '640px', + md: '768px', + lg: '1024px', + xl: '1280px', + '2xl': '1536px', + '3xl': '1800px', + '4xl': '2500px', + '5xl': '3200px', + }, + + spacing: { + px: '1px', + 0: '0px', + 0.5: '0.125rem', + 1: '0.25rem', + 1.5: '0.375rem', + 2: '0.5rem', + 2.5: '0.625rem', + 3: '0.75rem', + 3.5: '0.875rem', + 4: '1rem', + 4.5: '1.125rem', + 5: '1.25rem', + 6: '1.5rem', + 7: '1.75rem', + 8: '2rem', + 9: '2.25rem', + 10: '2.5rem', + 11: '2.75rem', + 12: '3rem', + 14: '3.5rem', + 16: '4rem', + 18: '4.5rem', + 20: '5rem', + 24: '6rem', + 28: '7rem', + 32: '8rem', + 36: '9rem', + 40: '10rem', + 44: '11rem', + 48: '12rem', + 52: '13rem', + 56: '14rem', + 60: '15rem', + 64: '16rem', + 72: '18rem', + 80: '20rem', + 84: '22rem', + 90: '24rem', + 96: '26rem', + 100: '28rem', + 110: '32rem', + }, + + extend: { + gridTemplateColumns: { + dynamic: 'repeat(var(--tw-grid-cols-count), var(--tw-grid-cols-width))', + }, + + spacing: { + dynamic: 'var(--tw-spacing-size)', + }, + + colors: { + zinc: { + 50: '#fafafa', + 100: '#f4f4f5', + 200: '#e4e4e7', + 300: '#d4d4d8', + 400: '#a1a1aa', + 500: '#71717a', + 600: '#52525b', + 700: '#3f3f46', + 800: '#27272a', + 900: '#18181b', + 950: '#09090b', + }, + + gray: { + 50: '#f8fbfc', + 100: '#f3f6fa', + 200: '#edf2f7', + 300: '#e2e8f0', + 400: '#cbd5e0', + 500: '#a0aec0', + 600: '#718096', + 700: '#4a5568', + 800: '#2d3748', + 900: '#1a202c', + 950: '#0a1016', + }, + + teal: { + 50: '#f0fdfa', + 100: '#e0fcff', + 200: '#bef8fd', + 300: '#87eaf2', + 400: '#54d1db', + 500: '#38bec9', + 600: '#2cb1bc', + 700: '#14919b', + 800: '#0e7c86', + 900: '#005860', + 950: '#022c28', + }, + + blue: { + 50: '#eff6ff', + 100: '#ebf8ff', + 200: '#bee3f8', + 300: '#90cdf4', + 400: '#63b3ed', + 500: '#4299e1', + 600: '#3182ce', + 700: '#2b6cb0', + 800: '#2c5282', + 900: '#2a4365', + 950: '#172554', + }, + + green: { + 50: '#fcfff5', + 100: '#fafff3', + 200: '#eaf9c9', + 300: '#d1efa0', + 400: '#b2e16e', + 500: '#96ce4c', + 600: '#7bb53d', + 700: '#649934', + 800: '#507b2e', + 900: '#456829', + 950: '#355819', + }, + + fuchsia: { + 50: '#fdf4ff', + 100: '#fae8ff', + 200: '#f5d0fe', + 300: '#f0abfc', + 400: '#e879f9', + 500: '#d946ef', + 600: '#c026d3', + 700: '#a21caf', + 800: '#86198f', + 900: '#701a75', + 950: '#4a044e', + }, + + orange: { + 50: '#fff7ed', + 100: '#ffedd5', + 200: '#fedbb8', + 300: '#fbd38d', + 400: '#f6ad55', + 500: '#ed8936', + 600: '#dd6b20', + 700: '#c05621', + 800: '#9c4221', + 900: '#7b341e', + 950: '#431407', + }, + }, + }, }, - plugins: [], } diff --git a/tsconfig.json b/tsconfig.json index cb8982b..e2eb9cc 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -24,7 +24,14 @@ "allowJs": true, "baseUrl": ".", "outDir": "./web", - "rootDir": "./" + "rootDir": "./", + "paths": { + "components/*": ["src/components/*"], + "hooks/*": ["src/hooks/*"], + "scripts/*": ["src/scripts/*"], + "types/*": ["src/types/*"], + "utils/*": ["src/utils/*"], + } }, "include": [ "src/**/*", diff --git a/vite.config.ts b/vite.config.ts index b392268..6462473 100644 --- a/vite.config.ts +++ b/vite.config.ts @@ -1,8 +1,93 @@ -import { defineConfig } from 'vite' import vue from '@vitejs/plugin-vue' +import fs from 'node:fs' +import path from 'node:path' +import { defineConfig, Plugin } from 'vite' + +function css(): Plugin { + return { + name: 'vite-plugin-css-inject', + apply: 'build', + enforce: 'post', + generateBundle(_, bundle) { + const cssCode: string[] = [] + + for (const key in bundle) { + if (Object.prototype.hasOwnProperty.call(bundle, key)) { + const chunk = bundle[key] + if (chunk.type === 'asset' && chunk.fileName.endsWith('.css')) { + cssCode.push(chunk.source) + delete bundle[key] + } + } + } + + for (const key in bundle) { + if (Object.prototype.hasOwnProperty.call(bundle, key)) { + const chunk = bundle[key] + if (chunk.type === 'chunk' && /index-.*\.js$/.test(chunk.fileName)) { + const originalCode = chunk.code + chunk.code = '(function(){var s=document.createElement("style");' + chunk.code += 's.type="text/css",s.dataset.styleId="model-manager",' + chunk.code += 's.appendChild(document.createTextNode(' + chunk.code += JSON.stringify(cssCode.join('')) + chunk.code += ')),document.head.appendChild(s);})();' + chunk.code += originalCode + } + } + } + }, + } +} + +function output(): Plugin { + return { + name: 'vite-plugin-output-fix', + apply: 'build', + enforce: 'post', + generateBundle(_, bundle) { + for (const key in bundle) { + const chunk = bundle[key] + + if (chunk.type === 'asset') { + if (chunk.fileName === 'index.html') { + delete bundle[key] + } + } + + if (chunk.fileName.startsWith('assets/')) { + chunk.fileName = chunk.fileName.replace('assets/', '') + } + } + }, + } +} + +function dev(): Plugin { + return { + name: 'vite-plugin-dev-fix', + apply: 'serve', + enforce: 'post', + configureServer(server) { + server.httpServer?.on('listening', () => { + const rootDir = server.config.root + const outDir = server.config.build.outDir + + const outDirPath = path.join(rootDir, outDir) + if (fs.existsSync(outDirPath)) { + fs.rmSync(outDirPath, { recursive: true }) + } + fs.mkdirSync(outDirPath) + + const port = server.config.server.port + const content = `import "http://127.0.0.1:${port}/src/main.ts";` + fs.writeFileSync(path.join(outDirPath, 'manager-dev.js'), content) + }) + }, + } +} export default defineConfig({ - plugins: [vue()], + plugins: [vue(), css(), output(), dev()], build: { outDir: 'web', @@ -13,6 +98,25 @@ export default defineConfig({ // Disabling tree-shaking // Prevent vite remove unused exports treeshake: true, + output: { + manualChunks(id) { + if (id.includes('primevue')) { + return 'primevue' + } + }, + }, + }, + chunkSizeWarningLimit: 1024, + }, + + resolve: { + alias: { + src: resolvePath('src'), + components: resolvePath('src/components'), + hooks: resolvePath('src/hooks'), + scripts: resolvePath('src/scripts'), + types: resolvePath('src/types'), + utils: resolvePath('src/utils'), }, }, @@ -23,3 +127,7 @@ export default defineConfig({ minifyWhitespace: true, }, }) + +function resolvePath(str: string) { + return path.resolve(__dirname, str) +}