+# Usage
-
+```bash
+cd /path/to/ComfyUI/custom_nodes
+git clone https://github.com/hayden-fr/ComfyUI-Model-Manager.git
+cd /path/to/ComfyUI/custom_nodes/ComfyUI-Model-Manager
+npm install
+npm run build
+```
## Features
-### Node Graph
+## Freely adjust size and position
-
+
+
+### Support Node Graph
+
+
- Drag a model thumbnail onto the graph to add a new node.
- Drag a model thumbnail onto an existing node to set the input field.
- If there are multiple valid possible fields, then the drag must be exact.
- Drag an embedding thumbnail onto a text area, or highlight any number of nodes, to append it onto the end of the text.
- Drag the preview image in a model's info view onto the graph to load the embedded workflow (if it exists).
-
-
+
- View multiple models associated with a url.
- Select a save directory and input a filename.
- Optionally set a model's preview image.
-- Optionally edit and save descriptions as a .txt note. (Default behavior can be set in the settings tab.)
-- Add Civitai and HuggingFace API tokens in `server_settings.yaml`.
+- Optionally edit and save descriptions as a .md note.
+- Add Civitai and HuggingFace API tokens in ComfyUI's settings.
+
+
### Models Tab
-
+
- Search in real-time for models using the search bar.
-- Use advance keyword search by typing `"multiple words in quotes"` or a minus sign before to `-exclude` a word or phrase.
-- Add `/` at the start of a search to view a dropdown list of subdirectories (for example, `/0/1.5/styles/clothing`).
- - Any directory paths in ComfyUI's `extra_model_paths.yaml` or directories added in `ComfyUI/models/` will automatically be detected.
-- Sort models by "Date Created", "Date Modified", "Name" and "File Size".
+- Sort models by "Name", "File Size", "Date Created" and "Date Modified".
### Model Info View
-
+
- View file info and metadata.
- Rename, move or **permanently** remove a model and all of it's related files.
-- Read, edit and save notes. (Saved as a `.txt` file beside the model).
- - `Ctrl+s` or `⌘+S` to save a note when the textarea is in focus.
- - Autosave can be enabled in settings. (Note: Once the model info view is closed, the undo history is lost.)
+- Read, edit and save notes. (Saved as a `.md` file beside the model).
- Change or remove a model's preview image.
- View training tags and use the random tag generator to generate prompt ideas. (Inspired by the one in A1111.)
-
-### Settings Tab
-
-
-
-- Settings are saved to `ui_settings.yaml`.
-- Most settings should update immediately, but a few may require a page reload to take effect.
-- Press the "Fix Extensions" button to correct all image file extensions in the model directories. (Note: This may take a minute or so to complete.)
diff --git a/__init__.py b/__init__.py
index 6aadb9b..1525a25 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,1228 +1,183 @@
import os
-import io
-import pathlib
-import shutil
-from datetime import datetime
-import sys
-import copy
-import importlib
-import re
-import base64
-
-from aiohttp import web
-import server
-import urllib.parse
-import urllib.request
-import struct
-import json
-import requests
-requests.packages.urllib3.disable_warnings()
-
-import comfy.utils
import folder_paths
-
-comfyui_model_uri = folder_paths.models_dir
-
-extension_uri = os.path.dirname(__file__)
-
-config_loader_path = os.path.join(extension_uri, 'config_loader.py')
-config_loader_spec = importlib.util.spec_from_file_location('config_loader', config_loader_path)
-config_loader = importlib.util.module_from_spec(config_loader_spec)
-config_loader_spec.loader.exec_module(config_loader)
-
-no_preview_image = os.path.join(extension_uri, "no-preview.png")
-ui_settings_uri = os.path.join(extension_uri, "ui_settings.yaml")
-server_settings_uri = os.path.join(extension_uri, "server_settings.yaml")
-
-fallback_model_extensions = set([".bin", ".ckpt", ".gguf", ".onnx", ".pt", ".pth", ".safetensors"]) # TODO: magic values
-jpeg_format_names = ["JPG", "JPEG", "JFIF"]
-image_extensions = (
- ".png", # order matters
- ".webp",
- ".jpeg",
- ".jpg",
- ".jfif",
- ".gif",
- ".apng",
-)
-stable_diffusion_webui_civitai_helper_image_extensions = (
- ".preview.png", # order matters
- ".preview.webp",
- ".preview.jpeg",
- ".preview.jpg",
- ".preview.jfif",
- ".preview.gif",
- ".preview.apng",
-)
-preview_extensions = ( # TODO: JavaScript does not know about this (x2 states)
- image_extensions + # order matters
- stable_diffusion_webui_civitai_helper_image_extensions
-)
-model_info_extension = ".txt"
-#video_extensions = (".avi", ".mp4", ".webm") # TODO: Requires ffmpeg or cv2. Cache preview frame?
-
-def split_valid_ext(s, *arg_exts):
- sl = s.lower()
- for exts in arg_exts:
- for ext in exts:
- if sl.endswith(ext.lower()):
- return (s[:-len(ext)], ext)
- return (s, "")
-
-_folder_names_and_paths = None # dict[str, tuple[list[str], list[str]]]
-def folder_paths_folder_names_and_paths(refresh = False):
- global _folder_names_and_paths
- if refresh or _folder_names_and_paths is None:
- _folder_names_and_paths = {}
- for item_name in os.listdir(comfyui_model_uri):
- item_path = os.path.join(comfyui_model_uri, item_name)
- if not os.path.isdir(item_path):
- continue
- if item_name == "configs":
- continue
- if item_name in folder_paths.folder_names_and_paths:
- dir_paths, extensions = copy.deepcopy(folder_paths.folder_names_and_paths[item_name])
- else:
- dir_paths = [item_path]
- extensions = copy.deepcopy(fallback_model_extensions)
- _folder_names_and_paths[item_name] = (dir_paths, extensions)
- return _folder_names_and_paths
-
-def folder_paths_get_folder_paths(folder_name, refresh = False): # API function crashes querying unknown model folder
- paths = folder_paths_folder_names_and_paths(refresh)
- if folder_name in paths:
- return paths[folder_name][0]
-
- maybe_path = os.path.join(comfyui_model_uri, folder_name)
- if os.path.exists(maybe_path):
- return [maybe_path]
- return []
-
-def folder_paths_get_supported_pt_extensions(folder_name, refresh = False): # Missing API function
- paths = folder_paths_folder_names_and_paths(refresh)
- if folder_name in paths:
- return paths[folder_name][1]
- model_extensions = copy.deepcopy(fallback_model_extensions)
- return model_extensions
+from .py import config
+from .py import utils
-def search_path_to_system_path(model_path):
- sep = os.path.sep
- model_path = os.path.normpath(model_path.replace("/", sep))
- model_path = model_path.lstrip(sep)
+# Init config settings
+config.extension_uri = os.path.dirname(__file__)
+utils.resolve_model_base_paths()
- isep1 = model_path.find(sep, 0)
- if isep1 == -1 or isep1 == len(model_path):
- return (None, None)
- isep2 = model_path.find(sep, isep1 + 1)
- if isep2 == -1 or isep2 - isep1 == 1:
- isep2 = len(model_path)
+import logging
+from aiohttp import web
+import traceback
+from .py import services
- model_path_type = model_path[0:isep1]
- paths = folder_paths_get_folder_paths(model_path_type)
- if len(paths) == 0:
- return (None, None)
- model_path_index = model_path[isep1 + 1:isep2]
+routes = config.routes
+
+
+@routes.get("/model-manager/ws")
+async def socket_handler(request):
+ """
+ Handle websocket connection.
+ """
+ ws = await services.connect_websocket(request)
+ return ws
+
+
+@routes.get("/model-manager/base-folders")
+async def get_model_paths(request):
+ """
+ Returns the base folders for models.
+ """
+ model_base_paths = config.model_base_paths
+ return web.json_response({"success": True, "data": model_base_paths})
+
+
+@routes.post("/model-manager/model")
+async def create_model(request):
+ """
+ Create a new model.
+
+ request body: x-www-form-urlencoded
+ - type: model type.
+ - pathIndex: index of the model folders.
+ - fullname: filename that relative to the model folder.
+ - previewFile: preview file.
+ - description: description.
+ - downloadPlatform: download platform.
+ - downloadUrl: download url.
+ - hash: a JSON string containing the hash value of the downloaded model.
+ """
+ post = await request.post()
try:
- model_path_index = int(model_path_index)
- except:
- return (None, None)
- if model_path_index < 0 or model_path_index >= len(paths):
- return (None, None)
-
- system_path = os.path.normpath(
- paths[model_path_index] +
- sep +
- model_path[isep2:]
- )
-
- return (system_path, model_path_type)
-
-
-def get_safetensor_header(path):
- try:
- header_bytes = comfy.utils.safetensors_header(path)
- header_json = json.loads(header_bytes)
- return header_json if header_json is not None else {}
- except:
- return {}
-
-
-def end_swap_and_pop(x, i):
- x[i], x[-1] = x[-1], x[i]
- return x.pop(-1)
-
-
-def model_type_to_dir_name(model_type):
- if model_type == "checkpoint": return "checkpoints"
- #elif model_type == "clip": return "clip"
- #elif model_type == "clip_vision": return "clip_vision"
- #elif model_type == "controlnet": return "controlnet"
- elif model_type == "diffuser": return "diffusers"
- elif model_type == "embedding": return "embeddings"
- #elif model_type== "gligen": return "gligen"
- elif model_type == "hypernetwork": return "hypernetworks"
- elif model_type == "lora": return "loras"
- #elif model_type == "style_models": return "style_models"
- #elif model_type == "unet": return "unet"
- elif model_type == "upscale_model": return "upscale_models"
- #elif model_type == "vae": return "vae"
- #elif model_type == "vae_approx": return "vae_approx"
- else: return model_type
-
-
-def ui_rules():
- Rule = config_loader.Rule
- return [
- Rule("model-search-always-append", "", str),
- Rule("model-default-browser-model-type", "checkpoints", str),
- Rule("model-real-time-search", True, bool),
- Rule("model-persistent-search", True, bool),
-
- Rule("model-preview-thumbnail-type", "AUTO", str),
- Rule("model-preview-fallback-search-safetensors-thumbnail", False, bool),
- Rule("model-show-label-extensions", False, bool),
- Rule("model-show-add-button", True, bool),
- Rule("model-show-copy-button", True, bool),
- Rule("model-show-load-workflow-button", True, bool),
- Rule("model-info-button-on-left", False, bool),
-
- Rule("model-add-embedding-extension", False, bool),
- Rule("model-add-drag-strict-on-field", False, bool),
- Rule("model-add-offset", 25, int),
-
- Rule("model-info-autosave-notes", False, bool),
-
- Rule("download-save-description-as-text-file", True, bool),
-
- Rule("sidebar-control-always-compact", False, bool),
- Rule("sidebar-default-width", 0.5, float, 0.0, 1.0),
- Rule("sidebar-default-height", 0.5, float, 0.0, 1.0),
- Rule("text-input-always-hide-search-button", False, bool),
- Rule("text-input-always-hide-clear-button", False, bool),
-
- Rule("tag-generator-sampler-method", "Frequency", str),
- Rule("tag-generator-count", 10, int),
- Rule("tag-generator-threshold", 2, int),
- ]
-
-
-def server_rules():
- Rule = config_loader.Rule
- return [
- #Rule("model_extension_download_whitelist", [".safetensors"], list),
- Rule("civitai_api_key", "", str),
- Rule("huggingface_api_key", "", str),
- ]
-server_settings = config_loader.yaml_load(server_settings_uri, server_rules())
-config_loader.yaml_save(server_settings_uri, server_rules(), server_settings)
-
-
-def get_def_headers(url=""):
- def_headers = {
- "User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148",
- }
-
- if url.startswith("https://civitai.com/"):
- api_key = server_settings["civitai_api_key"]
- if (api_key != ""):
- def_headers["Authorization"] = f"Bearer {api_key}"
- url += "&" if "?" in url else "?" # not the most robust solution
- url += f"token={api_key}" # TODO: Authorization didn't work in the header
- elif url.startswith("https://huggingface.co/"):
- api_key = server_settings["huggingface_api_key"]
- if api_key != "":
- def_headers["Authorization"] = f"Bearer {api_key}"
-
- return def_headers
-
-
-@server.PromptServer.instance.routes.get("/model-manager/timestamp")
-async def get_timestamp(request):
- return web.json_response({ "timestamp": datetime.now().timestamp() })
-
-
-@server.PromptServer.instance.routes.get("/model-manager/settings/load")
-async def load_ui_settings(request):
- rules = ui_rules()
- settings = config_loader.yaml_load(ui_settings_uri, rules)
- return web.json_response({ "settings": settings })
-
-
-@server.PromptServer.instance.routes.post("/model-manager/settings/save")
-async def save_ui_settings(request):
- body = await request.json()
- settings = body.get("settings")
- rules = ui_rules()
- validated_settings = config_loader.validated(rules, settings)
- success = config_loader.yaml_save(ui_settings_uri, rules, validated_settings)
- print("Saved file: " + ui_settings_uri)
- return web.json_response({
- "success": success,
- "settings": validated_settings if success else "",
- })
-
-
-from PIL import Image, TiffImagePlugin
-from PIL.PngImagePlugin import PngInfo
-def PIL_cast_serializable(v):
- # source: https://github.com/python-pillow/Pillow/issues/6199#issuecomment-1214854558
- if isinstance(v, TiffImagePlugin.IFDRational):
- return float(v)
- elif isinstance(v, tuple):
- return tuple(PIL_cast_serializable(t) for t in v)
- elif isinstance(v, bytes):
- return v.decode(errors="replace")
- elif isinstance(v, dict):
- for kk, vv in v.items():
- v[kk] = PIL_cast_serializable(vv)
- return v
- else:
- return v
-
-
-def get_safetensors_image_bytes(path):
- if not os.path.isfile(path):
- raise RuntimeError("Path was invalid!")
- header = get_safetensor_header(path)
- metadata = header.get("__metadata__", None)
- if metadata is None:
- return None
- thumbnail = metadata.get("modelspec.thumbnail", None)
- if thumbnail is None:
- return None
- image_data = thumbnail.split(',')[1]
- return base64.b64decode(image_data)
-
-
-def get_image_info(image):
- metadata = None
- if len(image.info) > 0:
- metadata = PngInfo()
- for (key, value) in image.info.items():
- value_str = str(PIL_cast_serializable(value)) # not sure if this is correct (sometimes includes exif)
- metadata.add_text(key, value_str)
- return metadata
-
-
-def image_format_is_equal(f1, f2):
- if not isinstance(f1, str) or not isinstance(f2, str):
- return False
- if f1[0] == ".": f1 = f1[1:]
- if f2[0] == ".": f2 = f2[1:]
- f1 = f1.upper()
- f2 = f2.upper()
- return f1 == f2 or (f1 in jpeg_format_names and f2 in jpeg_format_names)
-
-
-def get_auto_thumbnail_format(original_format):
- if original_format in ["JPEG", "WEBP", "JPG"]: # JFIF?
- return original_format
- return "JPEG" # default fallback
-
-
-@server.PromptServer.instance.routes.get("/model-manager/preview/get")
-async def get_model_preview(request):
- uri = request.query.get("uri")
- quality = 75
- response_image_format = request.query.get("image-format", None)
- if isinstance(response_image_format, str):
- response_image_format = response_image_format.upper()
-
- image_path = no_preview_image
- file_name = os.path.split(no_preview_image)[1]
- if uri != "no-preview":
- sep = os.path.sep
- uri = uri.replace("/" if sep == "\\" else "/", sep)
- path, _ = search_path_to_system_path(uri)
- head, extension = split_valid_ext(path, preview_extensions)
- if os.path.exists(path):
- image_path = path
- file_name = os.path.split(head)[1] + extension
- elif os.path.exists(head) and head.endswith(".safetensors"):
- image_path = head
- file_name = os.path.splitext(os.path.split(head)[1])[0] + extension
-
- w = request.query.get("width")
- h = request.query.get("height")
- try:
- w = int(w)
- if w < 1:
- w = None
- except:
- w = None
- try:
- h = int(h)
- if w < 1:
- h = None
- except:
- h = None
-
- image_data = None
- if w is None and h is None: # full size
- if image_path.endswith(".safetensors"):
- image_data = get_safetensors_image_bytes(image_path)
- else:
- with open(image_path, "rb") as image:
- image_data = image.read()
- fp = io.BytesIO(image_data)
- with Image.open(fp) as image:
- image_format = image.format
- if response_image_format is None:
- response_image_format = image_format
- elif response_image_format == "AUTO":
- response_image_format = get_auto_thumbnail_format(image_format)
-
- if not image_format_is_equal(response_image_format, image_format):
- exif = image.getexif()
- metadata = get_image_info(image)
- if response_image_format in jpeg_format_names:
- image = image.convert('RGB')
- image_bytes = io.BytesIO()
- image.save(image_bytes, format=response_image_format, exif=exif, pnginfo=metadata, quality=quality)
- image_data = image_bytes.getvalue()
- else:
- if image_path.endswith(".safetensors"):
- image_data = get_safetensors_image_bytes(image_path)
- fp = io.BytesIO(image_data)
- else:
- fp = image_path
-
- with Image.open(fp) as image:
- image_format = image.format
- if response_image_format is None:
- response_image_format = image_format
- elif response_image_format == "AUTO":
- response_image_format = get_auto_thumbnail_format(image_format)
-
- w0, h0 = image.size
- if w is None:
- w = (h * w0) // h0
- elif h is None:
- h = (w * h0) // w0
-
- exif = image.getexif()
- metadata = get_image_info(image)
-
- ratio_original = w0 / h0
- ratio_thumbnail = w / h
- if abs(ratio_original - ratio_thumbnail) < 0.01:
- crop_box = (0, 0, w0, h0)
- elif ratio_original > ratio_thumbnail:
- crop_width_fp = h0 * w / h
- x0 = int((w0 - crop_width_fp) / 2)
- crop_box = (x0, 0, x0 + int(crop_width_fp), h0)
- else:
- crop_height_fp = w0 * h / w
- y0 = int((h0 - crop_height_fp) / 2)
- crop_box = (0, y0, w0, y0 + int(crop_height_fp))
- image = image.crop(crop_box)
-
- if w < w0 and h < h0:
- resampling_method = Image.Resampling.BOX
- else:
- resampling_method = Image.Resampling.BICUBIC
- image.thumbnail((w, h), resample=resampling_method)
-
- if not image_format_is_equal(image_format, response_image_format) and response_image_format in jpeg_format_names:
- image = image.convert('RGB')
- image_bytes = io.BytesIO()
- image.save(image_bytes, format=response_image_format, exif=exif, pnginfo=metadata, quality=quality)
- image_data = image_bytes.getvalue()
-
- response_file_name = os.path.splitext(file_name)[0] + '.' + response_image_format.lower()
- return web.Response(
- headers={
- "Content-Disposition": f"inline; filename={response_file_name}",
- },
- body=image_data,
- content_type="image/" + response_image_format.lower(),
- )
-
-
-@server.PromptServer.instance.routes.get("/model-manager/image/extensions")
-async def get_image_extensions(request):
- return web.json_response(image_extensions)
-
-
-def download_model_preview(formdata):
- path = formdata.get("path", None)
- if type(path) is not str:
- raise ValueError("Invalid path!")
- path, model_type = search_path_to_system_path(path)
- model_type_extensions = folder_paths_get_supported_pt_extensions(model_type)
- path_without_extension, _ = split_valid_ext(path, model_type_extensions)
-
- overwrite = formdata.get("overwrite", "true").lower()
- overwrite = True if overwrite == "true" else False
-
- image = formdata.get("image", None)
- if type(image) is str:
- civitai_image_url = "https://civitai.com/images/"
- if image.startswith(civitai_image_url):
- image_id = re.search(r"^\d+", image[len(civitai_image_url):]).group(0)
- image_id = str(int(image_id))
- image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}"
- def_headers = get_def_headers(image_info_url)
- response = requests.get(
- url=image_info_url,
- stream=False,
- verify=False,
- headers=def_headers,
- proxies=None,
- allow_redirects=False,
- )
- if response.ok:
- content_type = response.headers.get("Content-Type")
- info = response.json()
- items = info["items"]
- if len(items) == 0:
- raise RuntimeError("Civitai /api/v1/images returned 0 items!")
- image = items[0]["url"]
- else:
- raise RuntimeError("Bad response from api/v1/images!")
- _, image_extension = split_valid_ext(image, image_extensions)
- if image_extension == "":
- raise ValueError("Invalid image type!")
- image_path = path_without_extension + image_extension
- download_file(image, image_path, overwrite)
- else:
- content_type = image.content_type
- if not content_type.startswith("image/"):
- raise RuntimeError("Invalid content type!")
- image_extension = "." + content_type[len("image/"):]
- if image_extension not in image_extensions:
- raise RuntimeError("Invalid extension!")
-
- image_path = path_without_extension + image_extension
- if not overwrite and os.path.isfile(image_path):
- raise RuntimeError("Image already exists!")
- file: io.IOBase = image.file
- image_data = file.read()
- with open(image_path, "wb") as f:
- f.write(image_data)
- print("Saved file: " + image_path)
-
- if overwrite:
- delete_same_name_files(path_without_extension, preview_extensions, image_extension)
-
- # detect (and try to fix) wrong file extension
- image_format = None
- with Image.open(image_path) as image:
- image_format = image.format
- image_dir_and_name, image_ext = os.path.splitext(image_path)
- if not image_format_is_equal(image_format, image_ext):
- corrected_image_path = image_dir_and_name + "." + image_format.lower()
- if os.path.exists(corrected_image_path) and not overwrite:
- print("WARNING: '" + image_path + "' has wrong extension!")
- else:
- os.rename(image_path, corrected_image_path)
- print("Saved file: " + corrected_image_path)
- image_path = corrected_image_path
- return image_path # return in-case need corrected path
-
-
-@server.PromptServer.instance.routes.post("/model-manager/preview/set")
-async def set_model_preview(request):
- formdata = await request.post()
- try:
- download_model_preview(formdata)
- return web.json_response({ "success": True })
- except ValueError as e:
- print(e, file=sys.stderr, flush=True)
- return web.json_response({
- "success": False,
- "alert": "Failed to set preview!\n\n" + str(e),
- })
-
-
-@server.PromptServer.instance.routes.post("/model-manager/preview/delete")
-async def delete_model_preview(request):
- result = { "success": False }
-
- model_path = request.query.get("path", None)
- if model_path is None:
- result["alert"] = "Missing model path!"
- return web.json_response(result)
- model_path = urllib.parse.unquote(model_path)
-
- model_path, model_type = search_path_to_system_path(model_path)
- model_extensions = folder_paths_get_supported_pt_extensions(model_type)
- path_and_name, _ = split_valid_ext(model_path, model_extensions)
- delete_same_name_files(path_and_name, preview_extensions)
-
- result["success"] = True
- return web.json_response(result)
-
-
-def correct_image_extensions(root_dir):
- detected_image_count = 0
- corrected_image_count = 0
- for root, dirs, files in os.walk(root_dir):
- for file_name in files:
- file_path = root + os.path.sep + file_name
- image_format = None
- try:
- with Image.open(file_path) as image:
- image_format = image.format
- except:
- continue
- image_path = file_path
- image_dir_and_name, image_ext = os.path.splitext(image_path)
- if not image_format_is_equal(image_format, image_ext):
- detected_image_count += 1
- corrected_image_path = image_dir_and_name + "." + image_format.lower()
- if os.path.exists(corrected_image_path):
- print("WARNING: '" + image_path + "' has wrong extension!")
- else:
- try:
- os.rename(image_path, corrected_image_path)
- except:
- print("WARNING: Unable to rename '" + image_path + "'!")
- continue
- ext0 = os.path.splitext(image_path)[1]
- ext1 = os.path.splitext(corrected_image_path)[1]
- print(f"({ext0} -> {ext1}): {corrected_image_path}")
- corrected_image_count += 1
- return (detected_image_count, corrected_image_count)
-
-
-@server.PromptServer.instance.routes.get("/model-manager/preview/correct-extensions")
-async def correct_preview_extensions(request):
- result = { "success": False }
-
- detected = 0
- corrected = 0
-
- model_types = os.listdir(comfyui_model_uri)
- model_types.remove("configs")
- model_types.sort()
-
- for model_type in model_types:
- for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)):
- if not os.path.exists(model_base_path): # TODO: Bug in main code? ("ComfyUI\output\checkpoints", "ComfyUI\output\clip", "ComfyUI\models\t2i_adapter", "ComfyUI\output\vae")
- continue
- d, c = correct_image_extensions(model_base_path)
- detected += d
- corrected += c
-
- result["success"] = True
- result["detected"] = detected
- result["corrected"] = corrected
- return web.json_response(result)
-
-
-@server.PromptServer.instance.routes.get("/model-manager/models/list")
-async def get_model_list(request):
- use_safetensor_thumbnail = (
- config_loader.yaml_load(ui_settings_uri, ui_rules())
- .get("model-preview-fallback-search-safetensors-thumbnail", False)
- )
-
- model_types = os.listdir(comfyui_model_uri)
- model_types.remove("configs")
- model_types.sort()
-
- models = {}
- for model_type in model_types:
- model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type))
- file_infos = []
- for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)):
- if not os.path.exists(model_base_path): # TODO: Bug in main code? ("ComfyUI\output\checkpoints", "ComfyUI\output\clip", "ComfyUI\models\t2i_adapter", "ComfyUI\output\vae")
- continue
- for cwd, subdirs, files in os.walk(model_base_path):
- dir_models = []
- dir_images = []
-
- for file in files:
- if file.lower().endswith(model_extensions):
- dir_models.append(file)
- elif file.lower().endswith(preview_extensions):
- dir_images.append(file)
-
- for model in dir_models:
- model_name, model_ext = split_valid_ext(model, model_extensions)
- image = None
- image_modified = None
- for ext in preview_extensions: # order matters
- for iImage in range(len(dir_images)-1, -1, -1):
- image_name = dir_images[iImage]
- if not image_name.lower().endswith(ext.lower()):
- continue
- image_name = image_name[:-len(ext)]
- if model_name == image_name:
- image = end_swap_and_pop(dir_images, iImage)
- img_abs_path = os.path.join(cwd, image)
- image_modified = pathlib.Path(img_abs_path).stat().st_mtime_ns
- break
- if image is not None:
- break
- abs_path = os.path.join(cwd, model)
- stats = pathlib.Path(abs_path).stat()
- sizeBytes = stats.st_size
- model_modified = stats.st_mtime_ns
- model_created = stats.st_ctime_ns
- if use_safetensor_thumbnail and image is None and model_ext == ".safetensors":
- # try to fallback on safetensor embedded thumbnail
- header = get_safetensor_header(abs_path)
- metadata = header.get("__metadata__", None)
- if metadata is not None:
- thumbnail = metadata.get("modelspec.thumbnail", None)
- if thumbnail is not None:
- i0 = thumbnail.find("/") + 1
- i1 = thumbnail.find(";")
- image_ext = "." + thumbnail[i0:i1]
- if image_ext in image_extensions:
- image = model + image_ext
- image_modified = model_modified
- rel_path = "" if cwd == model_base_path else os.path.relpath(cwd, model_base_path)
- info = (
- model,
- image,
- base_path_index,
- rel_path,
- model_modified,
- model_created,
- image_modified,
- sizeBytes,
- )
- file_infos.append(info)
- #file_infos.sort(key=lambda tup: tup[4], reverse=True) # TODO: remove sort; sorted on client
-
- model_items = []
- for model, image, base_path_index, rel_path, model_modified, model_created, image_modified, sizeBytes in file_infos:
- item = {
- "name": model,
- "path": "/" + os.path.join(model_type, str(base_path_index), rel_path, model).replace(os.path.sep, "/"), # relative logical path
- #"systemPath": os.path.join(rel_path, model), # relative system path (less information than "search path")
- "dateModified": model_modified,
- "dateCreated": model_created,
- #"dateLastUsed": "", # TODO: track server-side, send increment client-side
- #"countUsed": 0, # TODO: track server-side, send increment client-side
- "sizeBytes": sizeBytes,
- }
- if image is not None:
- raw_post = os.path.join(model_type, str(base_path_index), rel_path, image)
- item["preview"] = {
- "path": urllib.parse.quote_plus(raw_post),
- "dateModified": urllib.parse.quote_plus(str(image_modified)),
- }
- model_items.append(item)
-
- models[model_type] = model_items
-
- return web.json_response(models)
-
-
-def linear_directory_hierarchy(refresh = False):
- model_paths = folder_paths_folder_names_and_paths(refresh)
- dir_list = []
- dir_list.append({ "name": "", "childIndex": 1, "childCount": len(model_paths) })
- for model_dir_name, (model_dirs, _) in model_paths.items():
- dir_list.append({ "name": model_dir_name, "childIndex": None, "childCount": len(model_dirs) })
- for model_dir_index, (_, (model_dirs, extension_whitelist)) in enumerate(model_paths.items()):
- model_dir_child_index = len(dir_list)
- dir_list[model_dir_index + 1]["childIndex"] = model_dir_child_index
- for dir_path_index, dir_path in enumerate(model_dirs):
- dir_list.append({ "name": str(dir_path_index), "childIndex": None, "childCount": None })
- for dir_path_index, dir_path in enumerate(model_dirs):
- if not os.path.exists(dir_path) or os.path.isfile(dir_path):
- continue
-
- #dir_list.append({ "name": str(dir_path_index), "childIndex": None, "childCount": 0 })
- dir_stack = [(dir_path, model_dir_child_index + dir_path_index)]
- while len(dir_stack) > 0: # DEPTH-FIRST
- dir_path, dir_index = dir_stack.pop()
-
- dir_items = os.listdir(dir_path)
- dir_items = sorted(dir_items, key=str.casefold)
-
- dir_child_count = 0
-
- # TODO: sort content of directory: alphabetically
- # TODO: sort content of directory: files first
-
- subdirs = []
- for item_name in dir_items: # BREADTH-FIRST
- item_path = os.path.join(dir_path, item_name)
- if os.path.isdir(item_path):
- # dir
- subdir_index = len(dir_list) # this must be done BEFORE `dir_list.append`
- subdirs.append((item_path, subdir_index))
- dir_list.append({ "name": item_name, "childIndex": None, "childCount": 0 })
- dir_child_count += 1
- else:
- # file
- if extension_whitelist is None or split_valid_ext(item_name, extension_whitelist)[1] != "":
- dir_list.append({ "name": item_name })
- dir_child_count += 1
- if dir_child_count > 0:
- dir_list[dir_index]["childIndex"] = len(dir_list) - dir_child_count
- dir_list[dir_index]["childCount"] = dir_child_count
- subdirs.reverse()
- for dir_path, subdir_index in subdirs:
- dir_stack.append((dir_path, subdir_index))
- return dir_list
-
-
-@server.PromptServer.instance.routes.get("/model-manager/models/directory-list")
-async def get_directory_list(request):
- #body = await request.json()
- dir_list = linear_directory_hierarchy(True)
- #json.dump(dir_list, sys.stdout, indent=4)
- return web.json_response(dir_list)
-
-
-def download_file(url, filename, overwrite):
- if not overwrite and os.path.isfile(filename):
- raise ValueError("File already exists!")
-
- filename_temp = filename + ".download"
-
- def_headers = get_def_headers(url)
- rh = requests.get(
- url=url,
- stream=True,
- verify=False,
- headers=def_headers,
- proxies=None,
- allow_redirects=False,
- )
- if not rh.ok:
- raise ValueError(
- "Unable to download! Request header status code: " +
- str(rh.status_code)
- )
-
- downloaded_size = 0
- if rh.status_code == 200 and os.path.exists(filename_temp):
- downloaded_size = os.path.getsize(filename_temp)
-
- headers = {"Range": "bytes=%d-" % downloaded_size}
- headers["User-Agent"] = def_headers["User-Agent"]
- headers["Authorization"] = def_headers.get("Authorization", None)
-
- r = requests.get(
- url=url,
- stream=True,
- verify=False,
- headers=headers,
- proxies=None,
- allow_redirects=False,
- )
- if rh.status_code == 307 and r.status_code == 307:
- # Civitai redirect
- redirect_url = r.content.decode("utf-8")
- if not redirect_url.startswith("http"):
- # Civitai requires login (NSFW or user-required)
- # TODO: inform user WHY download failed
- raise ValueError("Unable to download from Civitai! Redirect url: " + str(redirect_url))
- download_file(redirect_url, filename, overwrite)
- return
- if rh.status_code == 302 and r.status_code == 302:
- # HuggingFace redirect
- redirect_url = r.content.decode("utf-8")
- redirect_url_index = redirect_url.find("http")
- if redirect_url_index == -1:
- raise ValueError("Unable to download from HuggingFace! Redirect url: " + str(redirect_url))
- download_file(redirect_url[redirect_url_index:], filename, overwrite)
- return
- elif rh.status_code == 200 and r.status_code == 206:
- # Civitai download link
- pass
-
- total_size = int(rh.headers.get("Content-Length", 0)) # TODO: pass in total size earlier
-
- print("Downloading file: " + url)
- if total_size != 0:
- print("Download file size: " + str(total_size))
-
- mode = "wb" if overwrite else "ab"
- with open(filename_temp, mode) as f:
- for chunk in r.iter_content(chunk_size=1024):
- if chunk is not None:
- downloaded_size += len(chunk)
- f.write(chunk)
- f.flush()
-
- if total_size != 0:
- fraction = 1 if downloaded_size == total_size else downloaded_size / total_size
- progress = int(50 * fraction)
- sys.stdout.reconfigure(encoding="utf-8")
- sys.stdout.write(
- "\r[%s%s] %d%%"
- % (
- "-" * progress,
- " " * (50 - progress),
- 100 * fraction,
- )
- )
- sys.stdout.flush()
- print()
-
- if overwrite and os.path.isfile(filename):
- os.remove(filename)
- os.rename(filename_temp, filename)
- print("Saved file: " + filename)
-
-
-def bytes_to_size(total_bytes):
- units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
- b = total_bytes
- i = 0
- while True:
- b = b >> 10
- if (b == 0): break
- i = i + 1
- if i >= len(units) or i == 0:
- return str(total_bytes) + " " + units[0]
- return "{:.2f}".format(total_bytes / (1 << (i * 10))) + " " + units[i]
-
-
-@server.PromptServer.instance.routes.get("/model-manager/model/info")
-async def get_model_info(request):
- result = { "success": False }
-
- model_path = request.query.get("path", None)
- if model_path is None:
- result["alert"] = "Missing model path!"
- return web.json_response(result)
- model_path = urllib.parse.unquote(model_path)
-
- abs_path, model_type = search_path_to_system_path(model_path)
- if abs_path is None:
- result["alert"] = "Invalid model path!"
- return web.json_response(result)
-
- info = {}
- comfyui_directory, name = os.path.split(model_path)
- info["File Name"] = name
- info["File Directory"] = comfyui_directory
- info["File Size"] = bytes_to_size(os.path.getsize(abs_path))
- stats = pathlib.Path(abs_path).stat()
- date_format = "%Y-%m-%d %H:%M:%S"
- date_modified = datetime.fromtimestamp(stats.st_mtime).strftime(date_format)
- #info["Date Modified"] = date_modified
- #info["Date Created"] = datetime.fromtimestamp(stats.st_ctime).strftime(date_format)
-
- model_extensions = folder_paths_get_supported_pt_extensions(model_type)
- abs_name , _ = split_valid_ext(abs_path, model_extensions)
-
- for extension in preview_extensions:
- maybe_preview = abs_name + extension
- if os.path.isfile(maybe_preview):
- preview_path, _ = split_valid_ext(model_path, model_extensions)
- preview_modified = pathlib.Path(maybe_preview).stat().st_mtime_ns
- info["Preview"] = {
- "path": urllib.parse.quote_plus(preview_path + extension),
- "dateModified": urllib.parse.quote_plus(str(preview_modified)),
- }
- break
-
- header = get_safetensor_header(abs_path)
- metadata = header.get("__metadata__", None)
-
- if metadata is not None and info.get("Preview", None) is None:
- thumbnail = metadata.get("modelspec.thumbnail")
- if thumbnail is not None:
- i0 = thumbnail.find("/") + 1
- i1 = thumbnail.find(";", i0)
- thumbnail_extension = "." + thumbnail[i0:i1]
- if thumbnail_extension in image_extensions:
- info["Preview"] = {
- "path": request.query["path"] + thumbnail_extension,
- "dateModified": date_modified,
- }
-
- if metadata is not None:
- info["Base Training Model"] = metadata.get("ss_sd_model_name", "")
- info["Base Model Version"] = metadata.get("ss_base_model_version", "")
- info["Network Dimension"] = metadata.get("ss_network_dim", "")
- info["Network Alpha"] = metadata.get("ss_network_alpha", "")
-
- if metadata is not None:
- training_comment = metadata.get("ss_training_comment", "")
- info["Description"] = (
- metadata.get("modelspec.description", "") +
- "\n\n" +
- metadata.get("modelspec.usage_hint", "") +
- "\n\n" +
- training_comment if training_comment != "None" else ""
- ).strip()
-
- info_text_file = abs_name + model_info_extension
- notes = ""
- if os.path.isfile(info_text_file):
- with open(info_text_file, 'r', encoding="utf-8") as f:
- notes = f.read()
-
- if metadata is not None:
- img_buckets = metadata.get("ss_bucket_info", None)
- datasets = metadata.get("ss_datasets", None)
-
- if type(img_buckets) is str:
- img_buckets = json.loads(img_buckets)
- elif type(datasets) is str:
- datasets = json.loads(datasets)
- if isinstance(datasets, list):
- datasets = datasets[0]
- img_buckets = datasets.get("bucket_info", None)
- resolutions = {}
- if img_buckets is not None:
- buckets = img_buckets.get("buckets", {})
- for resolution in buckets.values():
- dim = resolution["resolution"]
- x, y = dim[0], dim[1]
- count = resolution["count"]
- resolutions[str(x) + "x" + str(y)] = count
- resolutions = list(resolutions.items())
- resolutions.sort(key=lambda x: x[1], reverse=True)
- info["Bucket Resolutions"] = resolutions
-
- tags = None
- if metadata is not None:
- dir_tags = metadata.get("ss_tag_frequency", "{}")
- if type(dir_tags) is str:
- dir_tags = json.loads(dir_tags)
- tags = {}
- for train_tags in dir_tags.values():
- for tag, count in train_tags.items():
- tags[tag] = tags.get(tag, 0) + count
- tags = list(tags.items())
- tags.sort(key=lambda x: x[1], reverse=True)
-
- result["success"] = True
- result["info"] = info
- if metadata is not None:
- result["metadata"] = metadata
- if tags is not None:
- result["tags"] = tags
- result["notes"] = notes
- return web.json_response(result)
-
-
-@server.PromptServer.instance.routes.get("/model-manager/system-separator")
-async def get_system_separator(request):
- return web.json_response(os.path.sep)
-
-
-@server.PromptServer.instance.routes.post("/model-manager/model/download")
-async def download_model(request):
- formdata = await request.post()
- result = { "success": False }
-
- overwrite = formdata.get("overwrite", "false").lower()
- overwrite = True if overwrite == "true" else False
-
- model_path = formdata.get("path", "/0")
- directory, model_type = search_path_to_system_path(model_path)
- if directory is None:
- result["alert"] = "Invalid save path!"
- return web.json_response(result)
-
- download_uri = formdata.get("download")
- if download_uri is None:
- result["alert"] = "Invalid download url!"
- return web.json_response(result)
-
- name = formdata.get("name")
- model_extensions = folder_paths_get_supported_pt_extensions(model_type)
- name_head, model_extension = split_valid_ext(name, model_extensions)
- name_without_extension = os.path.split(name_head)[1]
- if name_without_extension == "":
- result["alert"] = "Cannot have empty model name!"
- return web.json_response(result)
- if model_extension == "":
- result["alert"] = "Unrecognized model extension!"
- return web.json_response(result)
- file_name = os.path.join(directory, name)
- try:
- download_file(download_uri, file_name, overwrite)
+ task_id = await services.create_model_download_task(post)
+ return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e:
- print(e, file=sys.stderr, flush=True)
- result["alert"] = "Failed to download model!\n\n" + str(e)
- return web.json_response(result)
-
- image = formdata.get("image")
- if image is not None and image != "":
- try:
- download_model_preview({
- "path": model_path + os.sep + name,
- "image": image,
- "overwrite": formdata.get("overwrite"),
- })
- except Exception as e:
- print(e, file=sys.stderr, flush=True)
- result["alert"] = "Failed to download preview!\n\n" + str(e)
-
- result["success"] = True
- return web.json_response(result)
+ error_msg = f"Create model download task failed: {str(e)}"
+ logging.error(error_msg)
+ logging.debug(traceback.format_exc())
+ return web.json_response({"success": False, "error": error_msg})
-@server.PromptServer.instance.routes.post("/model-manager/model/move")
-async def move_model(request):
- body = await request.json()
- result = { "success": False }
-
- old_file = body.get("oldFile", None)
- if old_file is None:
- result["alert"] = "No model was given!"
- return web.json_response(result)
- old_file, old_model_type = search_path_to_system_path(old_file)
- if not os.path.isfile(old_file):
- result["alert"] = "Model does not exist!"
- return web.json_response(result)
- old_model_extensions = folder_paths_get_supported_pt_extensions(old_model_type)
- old_file_without_extension, model_extension = split_valid_ext(old_file, old_model_extensions)
- if model_extension == "":
- result["alert"] = "Invalid model extension!"
- return web.json_response(result)
-
- new_file = body.get("newFile", None)
- if new_file is None or new_file == "":
- result["alert"] = "New model name was invalid!"
- return web.json_response(result)
- new_file, new_model_type = search_path_to_system_path(new_file)
- if not new_file.endswith(model_extension):
- result["alert"] = "Cannot change model extension!"
- return web.json_response(result)
- if os.path.isfile(new_file):
- result["alert"] = "Cannot overwrite existing model!"
- return web.json_response(result)
- new_model_extensions = folder_paths_get_supported_pt_extensions(new_model_type)
- new_file_without_extension, new_model_extension = split_valid_ext(new_file, new_model_extensions)
- if model_extension != new_model_extension:
- result["alert"] = "Cannot change model extension!"
- return web.json_response(result)
- new_file_dir, new_file_name = os.path.split(new_file)
- if not os.path.isdir(new_file_dir):
- result["alert"] = "Destination directory does not exist!"
- return web.json_response(result)
- new_name_without_extension = os.path.splitext(new_file_name)[0]
- if new_file_name == new_name_without_extension or new_name_without_extension == "":
- result["alert"] = "New model name was empty!"
- return web.json_response(result)
-
- if old_file == new_file:
- # no-op
- result["success"] = True
- return web.json_response(result)
+@routes.get("/model-manager/models")
+async def read_models(request):
+ """
+ Scan all models and read their information.
+ """
try:
- shutil.move(old_file, new_file)
- print("Moved file: " + new_file)
- except ValueError as e:
- print(e, file=sys.stderr, flush=True)
- result["alert"] = "Failed to move model!\n\n" + str(e)
- return web.json_response(result)
- # TODO: this could overwrite existing files in destination; do a check beforehand?
- for extension in preview_extensions + (model_info_extension,):
- old_file = old_file_without_extension + extension
- if os.path.isfile(old_file):
- new_file = new_file_without_extension + extension
- try:
- shutil.move(old_file, new_file)
- print("Moved file: " + new_file)
- except ValueError as e:
- print(e, file=sys.stderr, flush=True)
- msg = result.get("alert","")
- if msg == "":
- result["alert"] = "Failed to move model resource file!\n\n" + str(e)
- else:
- result["alert"] = msg + "\n" + str(e)
-
- result["success"] = True
- return web.json_response(result)
+ result = []
+ model_base_paths = config.model_base_paths
+ for model_type in model_base_paths:
+ result.extend(services.scan_models_by_model_type(model_type))
+ result = [{"id": i, **x} for i, x in enumerate(result)]
+ return web.json_response({"success": True, "data": result})
+ except Exception as e:
+ error_msg = f"Read models failed: {str(e)}"
+ logging.error(error_msg)
+ logging.debug(traceback.format_exc())
+ return web.json_response({"success": False, "error": error_msg})
-def delete_same_name_files(path_without_extension, extensions, keep_extension=None):
- for extension in extensions:
- if extension == keep_extension: continue
- file = path_without_extension + extension
- if os.path.isfile(file):
- os.remove(file)
- print("Deleted file: " + file)
+@routes.put("/model-manager/model/{type}/{index}/{filename:.*}")
+async def update_model(request):
+ """
+ Update model information.
+
+ request body: x-www-form-urlencoded
+ - previewFile: preview file.
+ - description: description.
+ - type: model type.
+ - pathIndex: index of the model folders.
+ - fullname: filename that relative to the model folder.
+ All fields are optional, but type, pathIndex and fullname must appear together.
+ """
+ model_type = request.match_info.get("type", None)
+ index = int(request.match_info.get("index", None))
+ filename = request.match_info.get("filename", None)
+
+ post: dict = await request.post()
+
+ try:
+ model_path = utils.get_valid_full_path(model_type, index, filename)
+ if model_path is None:
+ raise RuntimeError(f"File {filename} not found")
+ services.update_model(model_path, post)
+ return web.json_response({"success": True})
+ except Exception as e:
+ error_msg = f"Update model failed: {str(e)}"
+ logging.error(error_msg)
+ logging.debug(traceback.format_exc())
+ return web.json_response({"success": False, "error": error_msg})
-@server.PromptServer.instance.routes.post("/model-manager/model/delete")
+@routes.delete("/model-manager/model/{type}/{index}/{filename:.*}")
async def delete_model(request):
- result = { "success": False }
+ """
+ Delete model.
+ """
+ model_type = request.match_info.get("type", None)
+ index = int(request.match_info.get("index", None))
+ filename = request.match_info.get("filename", None)
- model_path = request.query.get("path", None)
- if model_path is None:
- result["alert"] = "Missing model path!"
- return web.json_response(result)
- model_path = urllib.parse.unquote(model_path)
- model_path, model_type = search_path_to_system_path(model_path)
- if model_path is None:
- result["alert"] = "Invalid model path!"
- return web.json_response(result)
-
- model_extensions = folder_paths_get_supported_pt_extensions(model_type)
- path_and_name, model_extension = split_valid_ext(model_path, model_extensions)
- if model_extension == "":
- result["alert"] = "Cannot delete file!"
- return web.json_response(result)
-
- if os.path.isfile(model_path):
- os.remove(model_path)
- result["success"] = True
- print("Deleted file: " + model_path)
-
- delete_same_name_files(path_and_name, preview_extensions)
- delete_same_name_files(path_and_name, (model_info_extension,))
-
- return web.json_response(result)
+ try:
+ model_path = utils.get_valid_full_path(model_type, index, filename)
+ if model_path is None:
+ raise RuntimeError(f"File {filename} not found")
+ services.remove_model(model_path)
+ return web.json_response({"success": True})
+ except Exception as e:
+ error_msg = f"Delete model failed: {str(e)}"
+ logging.error(error_msg)
+ logging.debug(traceback.format_exc())
+ return web.json_response({"success": False, "error": error_msg})
-@server.PromptServer.instance.routes.post("/model-manager/notes/save")
-async def set_notes(request):
- body = await request.json()
- result = { "success": False }
+@routes.get("/model-manager/preview/{type}/{index}/{filename:.*}")
+async def read_model_preview(request):
+ """
+ Get the file stream of the specified image.
+ If the file does not exist, no-preview.png is returned.
- dt_epoch = body.get("timestamp", None)
+ :param type: The type of the model. eg.checkpoints, loras, vae, etc.
+ :param index: The index of the model folders.
+ :param filename: The filename of the image.
+ """
+ model_type = request.match_info.get("type", None)
+ index = int(request.match_info.get("index", None))
+ filename = request.match_info.get("filename", None)
- text = body.get("notes", None)
- if type(text) is not str:
- result["alert"] = "Invalid note!"
- return web.json_response(result)
+ extension_uri = config.extension_uri
- model_path = body.get("path", None)
- if type(model_path) is not str:
- result["alert"] = "Missing model path!"
- return web.json_response(result)
- model_path, model_type = search_path_to_system_path(model_path)
- model_extensions = folder_paths_get_supported_pt_extensions(model_type)
- file_path_without_extension, _ = split_valid_ext(model_path, model_extensions)
- filename = os.path.normpath(file_path_without_extension + model_info_extension)
-
- if dt_epoch is not None and os.path.exists(filename) and os.path.getmtime(filename) > dt_epoch:
- # discard late save
- result["success"] = True
- return web.json_response(result)
-
- if text.isspace() or text == "":
- if os.path.exists(filename):
- os.remove(filename)
- #print("Deleted file: " + filename) # autosave -> too verbose
- else:
- try:
- with open(filename, "w", encoding="utf-8") as f:
- f.write(text)
- if dt_epoch is not None:
- os.utime(filename, (dt_epoch, dt_epoch))
- #print("Saved file: " + filename) # autosave -> too verbose
- except ValueError as e:
- print(e, file=sys.stderr, flush=True)
- result["alert"] = "Failed to save notes!\n\n" + str(e)
- web.json_response(result)
+ try:
+ folders = folder_paths.get_folder_paths(model_type)
+ base_path = folders[index]
+ abs_path = os.path.join(base_path, filename)
+ except:
+ abs_path = extension_uri
- result["success"] = True
- return web.json_response(result)
+ if not os.path.isfile(abs_path):
+ abs_path = os.path.join(extension_uri, "assets", "no-preview.png")
+ return web.FileResponse(abs_path)
+
+
+@routes.get("/model-manager/preview/download/{filename}")
+async def read_download_preview(request):
+ filename = request.match_info.get("filename", None)
+ extension_uri = config.extension_uri
+
+ download_path = utils.get_download_path()
+ preview_path = os.path.join(download_path, filename)
+
+ if not os.path.isfile(preview_path):
+ preview_path = os.path.join(extension_uri, "assets", "no-preview.png")
+
+ return web.FileResponse(preview_path)
WEB_DIRECTORY = "web"
NODE_CLASS_MAPPINGS = {}
-__all__ = ["NODE_CLASS_MAPPINGS"]
+__all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS"]
diff --git a/no-preview.png b/assets/no-preview.png
similarity index 100%
rename from no-preview.png
rename to assets/no-preview.png
diff --git a/config_loader.py b/config_loader.py
deleted file mode 100644
index 2f9e895..0000000
--- a/config_loader.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import yaml
-from dataclasses import dataclass
-
-@dataclass
-class Rule:
- key: any
- value_default: any
- value_type: type
- value_min: any # int | float | None
- value_max: any # int | float | None
-
- def __init__(
- self,
- key,
- value_default,
- value_type: type,
- value_min: any = None, # int | float | None
- value_max: any = None, # int | float | None
- ):
- self.key = key
- self.value_default = value_default
- self.value_type = value_type
- self.value_min = value_min
- self.value_max = value_max
-
-def _get_valid_value(data: dict, r: Rule):
- if r.value_type != type(r.value_default):
- raise Exception(f"'value_type' does not match type of 'value_default'!")
- value = data.get(r.key)
- if value is None:
- value = r.value_default
- else:
- try:
- value = r.value_type(value)
- except:
- value = r.value_default
-
- value_is_numeric = r.value_type == int or r.value_type == float
- if value_is_numeric and r.value_min:
- if r.value_type != type(r.value_min):
- raise Exception(f"Type of 'value_type' does not match the type of 'value_min'!")
- value = max(r.value_min, value)
- if value_is_numeric and r.value_max:
- if r.value_type != type(r.value_max):
- raise Exception(f"Type of 'value_type' does not match the type of 'value_max'!")
- value = min(r.value_max, value)
-
- return value
-
-def validated(rules: list[Rule], data: dict = {}):
- valid = {}
- for r in rules:
- valid[r.key] = _get_valid_value(data, r)
- return valid
-
-def yaml_load(path, rules: list[Rule]):
- data = {}
- try:
- with open(path, 'r') as file:
- data = yaml.safe_load(file)
- except:
- pass
- return validated(rules, data)
-
-def yaml_save(path, rules: list[Rule], data: dict) -> bool:
- data = validated(rules, data)
- try:
- with open(path, 'w') as file:
- yaml.dump(data, file)
- return True
- except:
- return False
diff --git a/demo/beta-menu-model-manager-button-settings-group.png b/demo/beta-menu-model-manager-button-settings-group.png
deleted file mode 100644
index e5f0d7e..0000000
Binary files a/demo/beta-menu-model-manager-button-settings-group.png and /dev/null differ
diff --git a/demo/tab-download.png b/demo/tab-download.png
index 4e5f7a7..50ffa5f 100644
Binary files a/demo/tab-download.png and b/demo/tab-download.png differ
diff --git a/demo/tab-model-drag-add.gif b/demo/tab-model-drag-add.gif
deleted file mode 100644
index 897b474..0000000
Binary files a/demo/tab-model-drag-add.gif and /dev/null differ
diff --git a/demo/tab-model-info-overview.png b/demo/tab-model-info-overview.png
old mode 100644
new mode 100755
index 0637e75..bef89c4
Binary files a/demo/tab-model-info-overview.png and b/demo/tab-model-info-overview.png differ
diff --git a/demo/tab-model-node-graph.gif b/demo/tab-model-node-graph.gif
new file mode 100755
index 0000000..638bcdc
Binary files /dev/null and b/demo/tab-model-node-graph.gif differ
diff --git a/demo/tab-model-preview-thumbnail-buttons-example.png b/demo/tab-model-preview-thumbnail-buttons-example.png
deleted file mode 100644
index 8f96ba6..0000000
Binary files a/demo/tab-model-preview-thumbnail-buttons-example.png and /dev/null differ
diff --git a/demo/tab-models-dropdown.png b/demo/tab-models-dropdown.png
deleted file mode 100644
index e23f763..0000000
Binary files a/demo/tab-models-dropdown.png and /dev/null differ
diff --git a/demo/tab-models.gif b/demo/tab-models.gif
new file mode 100755
index 0000000..b3e28f0
Binary files /dev/null and b/demo/tab-models.gif differ
diff --git a/demo/tab-models.png b/demo/tab-models.png
old mode 100644
new mode 100755
index 672ecea..f2a8825
Binary files a/demo/tab-models.png and b/demo/tab-models.png differ
diff --git a/demo/tab-settings.png b/demo/tab-settings.png
old mode 100644
new mode 100755
index 13ee3ef..04f3b49
Binary files a/demo/tab-settings.png and b/demo/tab-settings.png differ
diff --git a/package.json b/package.json
index 4b3fab8..0c1f49e 100644
--- a/package.json
+++ b/package.json
@@ -4,12 +4,16 @@
"version": "1.0.0",
"type": "module",
"scripts": {
- "dev": "vite build --watch",
+ "dev": "vite",
"build": "vite build",
"prepare": "husky"
},
"devDependencies": {
+ "@tailwindcss/container-queries": "^0.1.1",
+ "@types/lodash": "^4.17.9",
+ "@types/markdown-it": "^14.1.2",
"@types/node": "^22.5.5",
+ "@types/turndown": "^5.0.5",
"@vitejs/plugin-vue": "^5.1.4",
"autoprefixer": "^10.4.20",
"eslint": "^9.10.0",
@@ -20,15 +24,23 @@
"postcss": "^8.4.47",
"prettier": "^3.3.3",
"prettier-plugin-organize-imports": "^4.1.0",
+ "prettier-plugin-tailwindcss": "^0.6.8",
"tailwindcss": "^3.4.12",
"typescript": "^5.6.2",
"typescript-eslint": "^8.6.0",
"vite": "^5.4.6"
},
"dependencies": {
+ "@primevue/themes": "^4.0.7",
+ "dayjs": "^1.11.13",
+ "lodash": "^4.17.21",
+ "markdown-it": "^14.1.0",
+ "markdown-it-metadata-block": "^1.0.6",
"primevue": "^4.0.7",
+ "turndown": "^7.2.0",
"vue": "^3.4.31",
- "vue-i18n": "^9.13.1"
+ "vue-i18n": "^9.13.1",
+ "yaml": "^2.6.0"
},
"lint-staged": {
"./**/*.{js,ts,tsx,vue}": [
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index f75ee96..f155679 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -8,19 +8,52 @@ importers:
.:
dependencies:
+ '@primevue/themes':
+ specifier: ^4.0.7
+ version: 4.0.7
+ dayjs:
+ specifier: ^1.11.13
+ version: 1.11.13
+ lodash:
+ specifier: ^4.17.21
+ version: 4.17.21
+ markdown-it:
+ specifier: ^14.1.0
+ version: 14.1.0
+ markdown-it-metadata-block:
+ specifier: ^1.0.6
+ version: 1.0.6
primevue:
specifier: ^4.0.7
version: 4.0.7(vue@3.5.6(typescript@5.6.2))
+ turndown:
+ specifier: ^7.2.0
+ version: 7.2.0
vue:
specifier: ^3.4.31
version: 3.5.6(typescript@5.6.2)
vue-i18n:
specifier: ^9.13.1
version: 9.14.0(vue@3.5.6(typescript@5.6.2))
+ yaml:
+ specifier: ^2.6.0
+ version: 2.6.0
devDependencies:
+ '@tailwindcss/container-queries':
+ specifier: ^0.1.1
+ version: 0.1.1(tailwindcss@3.4.12)
+ '@types/lodash':
+ specifier: ^4.17.9
+ version: 4.17.9
+ '@types/markdown-it':
+ specifier: ^14.1.2
+ version: 14.1.2
'@types/node':
specifier: ^22.5.5
version: 22.5.5
+ '@types/turndown':
+ specifier: ^5.0.5
+ version: 5.0.5
'@vitejs/plugin-vue':
specifier: ^5.1.4
version: 5.1.4(vite@5.4.6(@types/node@22.5.5)(less@4.2.0))(vue@3.5.6(typescript@5.6.2))
@@ -51,6 +84,9 @@ importers:
prettier-plugin-organize-imports:
specifier: ^4.1.0
version: 4.1.0(prettier@3.3.3)(typescript@5.6.2)
+ prettier-plugin-tailwindcss:
+ specifier: ^0.6.8
+ version: 0.6.8(prettier-plugin-organize-imports@4.1.0(prettier@3.3.3)(typescript@5.6.2))(prettier@3.3.3)
tailwindcss:
specifier: ^3.4.12
version: 3.4.12
@@ -297,6 +333,9 @@ packages:
'@jridgewell/trace-mapping@0.3.25':
resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==}
+ '@mixmark-io/domino@2.2.0':
+ resolution: {integrity: sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw==}
+
'@nodelib/fs.scandir@2.1.5':
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
engines: {node: '>= 8'}
@@ -331,6 +370,10 @@ packages:
resolution: {integrity: sha512-tj4dfRdV5iN6O0mbkpjhMsGlT3wZTqOPL779ndY5gKuCwN5zcFmKmABWVQmr/ClRivnMkw6Yr1x6gRTV/N0ydg==}
engines: {node: '>=12.11.0'}
+ '@primevue/themes@4.0.7':
+ resolution: {integrity: sha512-ZbDUrpBmtuqdeegNwUaJTubaLDBBJWOc4Z6UoQM3DG2c7EAE19wQbuh+cG9zqA7sT/Xsp+ACC/Z9e4FnfqB55g==}
+ engines: {node: '>=12.11.0'}
+
'@rollup/rollup-android-arm-eabi@4.22.0':
resolution: {integrity: sha512-/IZQvg6ZR0tAkEi4tdXOraQoWeJy9gbQ/cx4I7k9dJaCk9qrXEcdouxRVz5kZXt5C2bQ9pILoAA+KB4C/d3pfw==}
cpu: [arm]
@@ -420,12 +463,32 @@ packages:
cpu: [x64]
os: [win32]
+ '@tailwindcss/container-queries@0.1.1':
+ resolution: {integrity: sha512-p18dswChx6WnTSaJCSGx6lTmrGzNNvm2FtXmiO6AuA1V4U5REyoqwmT6kgAsIMdjo07QdAfYXHJ4hnMtfHzWgA==}
+ peerDependencies:
+ tailwindcss: '>=3.2.0'
+
'@types/estree@1.0.5':
resolution: {integrity: sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==}
+ '@types/linkify-it@5.0.0':
+ resolution: {integrity: sha512-sVDA58zAw4eWAffKOaQH5/5j3XeayukzDk+ewSsnv3p4yJEZHCCzMDiZM8e0OUrRvmpGZ85jf4yDHkHsgBNr9Q==}
+
+ '@types/lodash@4.17.9':
+ resolution: {integrity: sha512-w9iWudx1XWOHW5lQRS9iKpK/XuRhnN+0T7HvdCCd802FYkT1AMTnxndJHGrNJwRoRHkslGr4S29tjm1cT7x/7w==}
+
+ '@types/markdown-it@14.1.2':
+ resolution: {integrity: sha512-promo4eFwuiW+TfGxhi+0x3czqTYJkG8qB17ZUJiVF10Xm7NLVRSLUsfRTU/6h1e24VvRnXCx+hG7li58lkzog==}
+
+ '@types/mdurl@2.0.0':
+ resolution: {integrity: sha512-RGdgjQUZba5p6QEFAVx2OGb8rQDL/cPRG7GiedRzMcJ1tYnUANBncjbSB1NRGwbvjcPeikRABz2nshyPk1bhWg==}
+
'@types/node@22.5.5':
resolution: {integrity: sha512-Xjs4y5UPO/CLdzpgR6GirZJx36yScjh73+2NlLlkFRSoQN8B0DpfXPdZGnvVmLRLOsqDpOfTNv7D9trgGhmOIA==}
+ '@types/turndown@5.0.5':
+ resolution: {integrity: sha512-TL2IgGgc7B5j78rIccBtlYAnkuv8nUQqhQc+DSYV5j9Be9XOcm/SKOVRuA47xAVI3680Tk9B1d8flK2GWT2+4w==}
+
'@typescript-eslint/eslint-plugin@8.6.0':
resolution: {integrity: sha512-UOaz/wFowmoh2G6Mr9gw60B1mm0MzUtm6Ic8G2yM1Le6gyj5Loi/N+O5mocugRGY+8OeeKmkMmbxNqUCq3B4Sg==}
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
@@ -667,6 +730,9 @@ packages:
csstype@3.1.3:
resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==}
+ dayjs@1.11.13:
+ resolution: {integrity: sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg==}
+
debug@4.3.7:
resolution: {integrity: sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==}
engines: {node: '>=6.0'}
@@ -1002,6 +1068,9 @@ packages:
lines-and-columns@1.2.4:
resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==}
+ linkify-it@5.0.0:
+ resolution: {integrity: sha512-5aHCbzQRADcdP+ATqnDuhhJ/MRIqDkZX5pyjFHRRysS8vZ5AbqGEoFIb6pYHPZ+L/OC2Lc+xT8uHVVR5CAK/wQ==}
+
lint-staged@15.2.10:
resolution: {integrity: sha512-5dY5t743e1byO19P9I4b3x8HJwalIznL5E1FWYnU6OWw33KxNBSLAc6Cy7F2PsFEO8FKnLwjwm5hx7aMF0jzZg==}
engines: {node: '>=18.12.0'}
@@ -1035,6 +1104,16 @@ packages:
resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==}
engines: {node: '>=6'}
+ markdown-it-metadata-block@1.0.6:
+ resolution: {integrity: sha512-0nMBdV/CLy/bFfcw3wFdiZ6sgEv/yWAoNxgb3qY+5lLEP804r/JT9yLmLH3Z3YrqGDHb5xIi7gqhj7gwbPHycQ==}
+
+ markdown-it@14.1.0:
+ resolution: {integrity: sha512-a54IwgWPaeBCAAsv13YgmALOF1elABB08FxO9i+r4VFk5Vl4pKokRPeX8u5TCgSsPi6ec1otfLjdOpVcgbpshg==}
+ hasBin: true
+
+ mdurl@2.0.0:
+ resolution: {integrity: sha512-Lf+9+2r+Tdp5wXDXC4PcIBjTDtq4UKjCPMQhKIuzpJNW0b96kVqSwW0bT7FhRSfmAiFYgP+SCRvdrDozfh0U5w==}
+
merge-stream@2.0.0:
resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==}
@@ -1244,6 +1323,61 @@ packages:
vue-tsc:
optional: true
+ prettier-plugin-tailwindcss@0.6.8:
+ resolution: {integrity: sha512-dGu3kdm7SXPkiW4nzeWKCl3uoImdd5CTZEJGxyypEPL37Wj0HT2pLqjrvSei1nTeuQfO4PUfjeW5cTUNRLZ4sA==}
+ engines: {node: '>=14.21.3'}
+ peerDependencies:
+ '@ianvs/prettier-plugin-sort-imports': '*'
+ '@prettier/plugin-pug': '*'
+ '@shopify/prettier-plugin-liquid': '*'
+ '@trivago/prettier-plugin-sort-imports': '*'
+ '@zackad/prettier-plugin-twig-melody': '*'
+ prettier: ^3.0
+ prettier-plugin-astro: '*'
+ prettier-plugin-css-order: '*'
+ prettier-plugin-import-sort: '*'
+ prettier-plugin-jsdoc: '*'
+ prettier-plugin-marko: '*'
+ prettier-plugin-multiline-arrays: '*'
+ prettier-plugin-organize-attributes: '*'
+ prettier-plugin-organize-imports: '*'
+ prettier-plugin-sort-imports: '*'
+ prettier-plugin-style-order: '*'
+ prettier-plugin-svelte: '*'
+ peerDependenciesMeta:
+ '@ianvs/prettier-plugin-sort-imports':
+ optional: true
+ '@prettier/plugin-pug':
+ optional: true
+ '@shopify/prettier-plugin-liquid':
+ optional: true
+ '@trivago/prettier-plugin-sort-imports':
+ optional: true
+ '@zackad/prettier-plugin-twig-melody':
+ optional: true
+ prettier-plugin-astro:
+ optional: true
+ prettier-plugin-css-order:
+ optional: true
+ prettier-plugin-import-sort:
+ optional: true
+ prettier-plugin-jsdoc:
+ optional: true
+ prettier-plugin-marko:
+ optional: true
+ prettier-plugin-multiline-arrays:
+ optional: true
+ prettier-plugin-organize-attributes:
+ optional: true
+ prettier-plugin-organize-imports:
+ optional: true
+ prettier-plugin-sort-imports:
+ optional: true
+ prettier-plugin-style-order:
+ optional: true
+ prettier-plugin-svelte:
+ optional: true
+
prettier@3.3.3:
resolution: {integrity: sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==}
engines: {node: '>=14'}
@@ -1256,6 +1390,10 @@ packages:
prr@1.0.1:
resolution: {integrity: sha512-yPw4Sng1gWghHQWj0B3ZggWUm4qVbPwPFcRG8KyxiU7J2OHFSoEHKS+EZ3fv5l1t9CyCiop6l/ZYeWbrgoQejw==}
+ punycode.js@2.3.1:
+ resolution: {integrity: sha512-uxFIHU0YlHYhDQtV4R9J6a52SLx28BCjT+4ieh7IGbgwVJWO+km431c4yRlREUAsAmt/uMjQUyQHNEPf0M39CA==}
+ engines: {node: '>=6'}
+
punycode@2.3.1:
resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==}
engines: {node: '>=6'}
@@ -1420,6 +1558,9 @@ packages:
tslib@2.7.0:
resolution: {integrity: sha512-gLXCKdN1/j47AiHiOkJN69hJmcbGTHI0ImLmbYLHykhgeN0jVGola9yVjFgzCUklsZQMW55o+dW7IXv3RCXDzA==}
+ turndown@7.2.0:
+ resolution: {integrity: sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==}
+
type-check@0.4.0:
resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==}
engines: {node: '>= 0.8.0'}
@@ -1442,6 +1583,9 @@ packages:
engines: {node: '>=14.17'}
hasBin: true
+ uc.micro@2.1.0:
+ resolution: {integrity: sha512-ARDJmphmdvUk6Glw7y9DQ2bFkKBHwQHLi2lsaH6PPmz/Ka9sFOBsBluozhDltWmnv9u/cF6Rt87znRTPV+yp/A==}
+
undici-types@6.19.8:
resolution: {integrity: sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==}
@@ -1538,6 +1682,11 @@ packages:
engines: {node: '>= 14'}
hasBin: true
+ yaml@2.6.0:
+ resolution: {integrity: sha512-a6ae//JvKDEra2kdi1qzCyrJW/WZCgFi8ydDV+eXExl95t+5R+ijnqHJbz9tmMh8FUjx3iv2fCQ4dclAQlO2UQ==}
+ engines: {node: '>= 14'}
+ hasBin: true
+
yocto-queue@0.1.0:
resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==}
engines: {node: '>=10'}
@@ -1708,6 +1857,8 @@ snapshots:
'@jridgewell/resolve-uri': 3.1.2
'@jridgewell/sourcemap-codec': 1.5.0
+ '@mixmark-io/domino@2.2.0': {}
+
'@nodelib/fs.scandir@2.1.5':
dependencies:
'@nodelib/fs.stat': 2.0.5
@@ -1742,6 +1893,10 @@ snapshots:
transitivePeerDependencies:
- vue
+ '@primevue/themes@4.0.7':
+ dependencies:
+ '@primeuix/styled': 0.0.5
+
'@rollup/rollup-android-arm-eabi@4.22.0':
optional: true
@@ -1790,12 +1945,29 @@ snapshots:
'@rollup/rollup-win32-x64-msvc@4.22.0':
optional: true
+ '@tailwindcss/container-queries@0.1.1(tailwindcss@3.4.12)':
+ dependencies:
+ tailwindcss: 3.4.12
+
'@types/estree@1.0.5': {}
+ '@types/linkify-it@5.0.0': {}
+
+ '@types/lodash@4.17.9': {}
+
+ '@types/markdown-it@14.1.2':
+ dependencies:
+ '@types/linkify-it': 5.0.0
+ '@types/mdurl': 2.0.0
+
+ '@types/mdurl@2.0.0': {}
+
'@types/node@22.5.5':
dependencies:
undici-types: 6.19.8
+ '@types/turndown@5.0.5': {}
+
'@typescript-eslint/eslint-plugin@8.6.0(@typescript-eslint/parser@8.6.0(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2))(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2)':
dependencies:
'@eslint-community/regexpp': 4.11.1
@@ -2074,6 +2246,8 @@ snapshots:
csstype@3.1.3: {}
+ dayjs@1.11.13: {}
+
debug@4.3.7:
dependencies:
ms: 2.1.3
@@ -2430,6 +2604,10 @@ snapshots:
lines-and-columns@1.2.4: {}
+ linkify-it@5.0.0:
+ dependencies:
+ uc.micro: 2.1.0
+
lint-staged@15.2.10:
dependencies:
chalk: 5.3.0
@@ -2482,6 +2660,19 @@ snapshots:
semver: 5.7.2
optional: true
+ markdown-it-metadata-block@1.0.6: {}
+
+ markdown-it@14.1.0:
+ dependencies:
+ argparse: 2.0.1
+ entities: 4.5.0
+ linkify-it: 5.0.0
+ mdurl: 2.0.0
+ punycode.js: 2.3.1
+ uc.micro: 2.1.0
+
+ mdurl@2.0.0: {}
+
merge-stream@2.0.0: {}
merge2@1.4.1: {}
@@ -2618,7 +2809,7 @@ snapshots:
postcss-load-config@4.0.2(postcss@8.4.47):
dependencies:
lilconfig: 3.1.2
- yaml: 2.5.1
+ yaml: 2.6.0
optionalDependencies:
postcss: 8.4.47
@@ -2647,6 +2838,12 @@ snapshots:
prettier: 3.3.3
typescript: 5.6.2
+ prettier-plugin-tailwindcss@0.6.8(prettier-plugin-organize-imports@4.1.0(prettier@3.3.3)(typescript@5.6.2))(prettier@3.3.3):
+ dependencies:
+ prettier: 3.3.3
+ optionalDependencies:
+ prettier-plugin-organize-imports: 4.1.0(prettier@3.3.3)(typescript@5.6.2)
+
prettier@3.3.3: {}
primevue@4.0.7(vue@3.5.6(typescript@5.6.2)):
@@ -2661,6 +2858,8 @@ snapshots:
prr@1.0.1:
optional: true
+ punycode.js@2.3.1: {}
+
punycode@2.3.1: {}
queue-microtask@1.2.3: {}
@@ -2849,6 +3048,10 @@ snapshots:
tslib@2.7.0: {}
+ turndown@7.2.0:
+ dependencies:
+ '@mixmark-io/domino': 2.2.0
+
type-check@0.4.0:
dependencies:
prelude-ls: 1.2.1
@@ -2868,6 +3071,8 @@ snapshots:
typescript@5.6.2: {}
+ uc.micro@2.1.0: {}
+
undici-types@6.19.8: {}
update-browserslist-db@1.1.0(browserslist@4.23.3):
@@ -2950,4 +3155,6 @@ snapshots:
yaml@2.5.1: {}
+ yaml@2.6.0: {}
+
yocto-queue@0.1.0: {}
diff --git a/py/config.py b/py/config.py
new file mode 100644
index 0000000..8efbdee
--- /dev/null
+++ b/py/config.py
@@ -0,0 +1,33 @@
+extension_uri: str = None
+model_base_paths: dict[str, list[str]] = {}
+
+
+setting_key = {
+ "api_key": {
+ "civitai": "ModelManager.APIKey.Civitai",
+ "huggingface": "ModelManager.APIKey.HuggingFace",
+ },
+ "download": {
+ "max_task_count": "ModelManager.Download.MaxTaskCount",
+ },
+}
+
+user_agent = "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148"
+
+
+from server import PromptServer
+
+serverInstance = PromptServer.instance
+routes = serverInstance.routes
+
+
+class FakeRequest:
+ def __init__(self):
+ self.headers = {}
+
+
+class CustomException(BaseException):
+ def __init__(self, type: str, message: str = None) -> None:
+ self.type = type
+ self.message = message
+ super().__init__(message)
diff --git a/py/download.py b/py/download.py
new file mode 100644
index 0000000..ca8f42d
--- /dev/null
+++ b/py/download.py
@@ -0,0 +1,362 @@
+import os
+import uuid
+import time
+import logging
+import requests
+import folder_paths
+import traceback
+from typing import Callable, Awaitable, Any, Literal, Union, Optional
+from dataclasses import dataclass
+from . import config
+from . import utils
+from . import socket
+from . import thread
+
+
+@dataclass
+class TaskStatus:
+ taskId: str
+ type: str
+ fullname: str
+ preview: str
+ status: Literal["pause", "waiting", "doing"] = "pause"
+ platform: Union[str, None] = None
+ downloadedSize: float = 0
+ totalSize: float = 0
+ progress: float = 0
+ bps: float = 0
+ error: Optional[str] = None
+
+
+@dataclass
+class TaskContent:
+ type: str
+ pathIndex: int
+ fullname: str
+ description: str
+ downloadPlatform: str
+ downloadUrl: str
+ sizeBytes: float
+ hashes: Optional[dict[str, str]] = None
+
+
+download_model_task_status: dict[str, TaskStatus] = {}
+download_thread_pool = thread.DownloadThreadPool()
+
+
+def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
+ download_path = utils.get_download_path()
+ task_file_path = os.path.join(download_path, f"{task_id}.task")
+ utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
+
+
+def get_task_content(task_id: str):
+ download_path = utils.get_download_path()
+ task_file = os.path.join(download_path, f"{task_id}.task")
+ if not os.path.isfile(task_file):
+ raise RuntimeError(f"Task {task_id} not found")
+ task_content = utils.load_dict_pickle_file(task_file)
+ task_content["pathIndex"] = int(task_content.get("pathIndex", 0))
+ task_content["sizeBytes"] = float(task_content.get("sizeBytes", 0))
+ return TaskContent(**task_content)
+
+
+def get_task_status(task_id: str):
+ task_status = download_model_task_status.get(task_id, None)
+
+ if task_status is None:
+ download_path = utils.get_download_path()
+ task_content = get_task_content(task_id)
+ download_file = os.path.join(download_path, f"{task_id}.download")
+ download_size = 0
+ if os.path.exists(download_file):
+ download_size = os.path.getsize(download_file)
+
+ total_size = task_content.sizeBytes
+ task_status = TaskStatus(
+ taskId=task_id,
+ type=task_content.type,
+ fullname=task_content.fullname,
+ preview=utils.get_model_preview_name(download_file),
+ platform=task_content.downloadPlatform,
+ downloadedSize=download_size,
+ totalSize=task_content.sizeBytes,
+ progress=download_size / total_size * 100 if total_size > 0 else 0,
+ )
+
+ download_model_task_status[task_id] = task_status
+
+ return task_status
+
+
+def delete_task_status(task_id: str):
+ download_model_task_status.pop(task_id, None)
+
+
+async def scan_model_download_task_list(sid: str):
+ """
+ Scan the download directory and send the task list to the client.
+ """
+ try:
+ download_dir = utils.get_download_path()
+ task_files = utils.search_files(download_dir)
+ task_files = folder_paths.filter_files_extensions(task_files, [".task"])
+ task_files = sorted(
+ task_files,
+ key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime,
+ reverse=True,
+ )
+ task_list: list[dict] = []
+ for task_file in task_files:
+ task_id = task_file.replace(".task", "")
+ task_status = get_task_status(task_id)
+ task_list.append(task_status)
+
+ await socket.send_json("downloadTaskList", task_list, sid)
+ except Exception as e:
+ error_msg = f"Refresh task list failed: {e}"
+ await socket.send_json("error", error_msg, sid)
+ logging.error(error_msg)
+
+
+async def create_model_download_task(post: dict):
+ """
+ Creates a download task for the given post.
+ """
+ model_type = post.get("type", None)
+ path_index = int(post.get("pathIndex", None))
+ fullname = post.get("fullname", None)
+
+ model_path = utils.get_full_path(model_type, path_index, fullname)
+ # Check if the model path is valid
+ if os.path.exists(model_path):
+ raise RuntimeError(f"File already exists: {model_path}")
+
+ download_path = utils.get_download_path()
+
+ task_id = uuid.uuid4().hex
+ task_path = os.path.join(download_path, f"{task_id}.task")
+ if os.path.exists(task_path):
+ raise RuntimeError(f"Task {task_id} already exists")
+
+ try:
+ previewFile = post.pop("previewFile", None)
+ utils.save_model_preview_image(task_path, previewFile)
+ set_task_content(task_id, post)
+ task_status = TaskStatus(
+ taskId=task_id,
+ type=model_type,
+ fullname=fullname,
+ preview=utils.get_model_preview_name(task_path),
+ platform=post.get("downloadPlatform", None),
+ totalSize=float(post.get("sizeBytes", 0)),
+ )
+ download_model_task_status[task_id] = task_status
+ await socket.send_json("createDownloadTask", task_status)
+ except Exception as e:
+ await delete_model_download_task(task_id)
+ raise RuntimeError(str(e)) from e
+
+ await download_model(task_id)
+ return task_id
+
+
+async def pause_model_download_task(task_id: str):
+ task_status = get_task_status(task_id=task_id)
+ task_status.status = "pause"
+
+
+async def delete_model_download_task(task_id: str):
+ task_status = get_task_status(task_id)
+ is_running = task_status.status == "doing"
+ task_status.status = "waiting"
+ await socket.send_json("deleteDownloadTask", task_id)
+
+ # Pause the task
+ if is_running:
+ task_status.status = "pause"
+ time.sleep(1)
+
+ download_dir = utils.get_download_path()
+ task_file_list = os.listdir(download_dir)
+ for task_file in task_file_list:
+ task_file_target = os.path.splitext(task_file)[0]
+ if task_file_target == task_id:
+ delete_task_status(task_id)
+ os.remove(os.path.join(download_dir, task_file))
+
+ await socket.send_json("deleteDownloadTask", task_id)
+
+
+async def download_model(task_id: str):
+ async def download_task(task_id: str):
+ async def report_progress(task_status: TaskStatus):
+ await socket.send_json("updateDownloadTask", task_status)
+
+ try:
+ # When starting a task from the queue, the task may not exist
+ task_status = get_task_status(task_id)
+ except:
+ return
+
+ # Update task status
+ task_status.status = "doing"
+ await socket.send_json("updateDownloadTask", task_status)
+
+ try:
+
+ # Set download request headers
+ headers = {"User-Agent": config.user_agent}
+
+ download_platform = task_status.platform
+ if download_platform == "civitai":
+ api_key = utils.get_setting_value("api_key.civitai")
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ elif download_platform == "huggingface":
+ api_key = utils.get_setting_value("api_key.huggingface")
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ progress_interval = 1.0
+ await download_model_file(
+ task_id=task_id,
+ headers=headers,
+ progress_callback=report_progress,
+ interval=progress_interval,
+ )
+ except Exception as e:
+ task_status.status = "pause"
+ task_status.error = str(e)
+ await socket.send_json("updateDownloadTask", task_status)
+ task_status.error = None
+ logging.error(str(e))
+
+ try:
+ status = download_thread_pool.submit(download_task, task_id)
+ if status == "Waiting":
+ task_status = get_task_status(task_id)
+ task_status.status = "waiting"
+ await socket.send_json("updateDownloadTask", task_status)
+ except Exception as e:
+ task_status.status = "pause"
+ task_status.error = str(e)
+ await socket.send_json("updateDownloadTask", task_status)
+ task_status.error = None
+ logging.error(traceback.format_exc())
+
+
+async def download_model_file(
+ task_id: str,
+ headers: dict,
+ progress_callback: Callable[[TaskStatus], Awaitable[Any]],
+ interval: float = 1.0,
+):
+
+ async def download_complete():
+ """
+ Restore the model information from the task file
+ and move the model file to the target directory.
+ """
+ model_type = task_content.type
+ path_index = task_content.pathIndex
+ fullname = task_content.fullname
+ # Write description file
+ description = task_content.description
+ description_file = os.path.join(download_path, f"{task_id}.md")
+ with open(description_file, "w") as f:
+ f.write(description)
+
+ model_path = utils.get_full_path(model_type, path_index, fullname)
+
+ utils.rename_model(download_tmp_file, model_path)
+
+ time.sleep(1)
+ task_file = os.path.join(download_path, f"{task_id}.task")
+ os.remove(task_file)
+ await socket.send_json("completeDownloadTask", task_id)
+
+ async def update_progress():
+ nonlocal last_update_time
+ nonlocal last_downloaded_size
+ progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
+ task_status.downloadedSize = downloaded_size
+ task_status.progress = progress
+ task_status.bps = downloaded_size - last_downloaded_size
+ await progress_callback(task_status)
+ last_update_time = time.time()
+ last_downloaded_size = downloaded_size
+
+ task_status = get_task_status(task_id)
+ task_content = get_task_content(task_id)
+
+ # Check download uri
+ model_url = task_content.downloadUrl
+ if not model_url:
+ raise RuntimeError("No downloadUrl found")
+
+ download_path = utils.get_download_path()
+ download_tmp_file = os.path.join(download_path, f"{task_id}.download")
+
+ downloaded_size = 0
+ if os.path.isfile(download_tmp_file):
+ downloaded_size = os.path.getsize(download_tmp_file)
+ headers["Range"] = f"bytes={downloaded_size}-"
+
+ total_size = task_content.sizeBytes
+
+ if total_size > 0 and downloaded_size == total_size:
+ await download_complete()
+ return
+
+ last_update_time = time.time()
+ last_downloaded_size = downloaded_size
+
+ response = requests.get(
+ url=model_url,
+ headers=headers,
+ stream=True,
+ allow_redirects=True,
+ )
+
+ if response.status_code not in (200, 206):
+ raise RuntimeError(
+ f"Failed to download {task_content.fullname}, status code: {response.status_code}"
+ )
+
+ # Some models require logging in before they can be downloaded.
+ # If no token is carried, it will be redirected to the login page.
+ content_type = response.headers.get("content-type")
+ if content_type and content_type.startswith("text/html"):
+ raise RuntimeError(
+ f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
+ )
+
+ # When parsing model information from HuggingFace API,
+ # the file size was not found and needs to be obtained from the response header.
+ if total_size == 0:
+ total_size = int(response.headers.get("content-length", 0))
+ task_content.sizeBytes = total_size
+ task_status.totalSize = total_size
+ set_task_content(task_id, task_content)
+ await socket.send_json("updateDownloadTask", task_content)
+
+ with open(download_tmp_file, "ab") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ if task_status.status == "pause":
+ break
+
+ f.write(chunk)
+ downloaded_size += len(chunk)
+
+ if time.time() - last_update_time >= interval:
+ await update_progress()
+
+ await update_progress()
+
+ if total_size > 0 and downloaded_size == total_size:
+ await download_complete()
+ else:
+ task_status.status = "pause"
+ await socket.send_json("updateDownloadTask", task_status)
diff --git a/py/services.py b/py/services.py
new file mode 100644
index 0000000..835b095
--- /dev/null
+++ b/py/services.py
@@ -0,0 +1,145 @@
+import os
+import logging
+import traceback
+import folder_paths
+
+from typing import Any
+from multidict import MultiDictProxy
+from . import utils
+from . import socket
+from . import download
+
+
+async def connect_websocket(request):
+ async def message_handler(event_type: str, detail: Any, sid: str):
+ try:
+ if event_type == "downloadTaskList":
+ await download.scan_model_download_task_list(sid=sid)
+
+ if event_type == "resumeDownloadTask":
+ await download.download_model(task_id=detail)
+
+ if event_type == "pauseDownloadTask":
+ await download.pause_model_download_task(task_id=detail)
+
+ if event_type == "deleteDownloadTask":
+ await download.delete_model_download_task(task_id=detail)
+ except Exception:
+ logging.error(traceback.format_exc())
+
+ ws = await socket.create_websocket_handler(request, handler=message_handler)
+ return ws
+
+
+def scan_models_by_model_type(model_type: str):
+ """
+ Scans all models in the given model type and returns a list of models.
+ """
+ out = []
+ folders, extensions = folder_paths.folder_names_and_paths[model_type]
+ for path_index, base_path in enumerate(folders):
+ files = utils.recursive_search_files(base_path)
+
+ models = folder_paths.filter_files_extensions(files, extensions)
+
+ for fullname in models:
+ """
+ fullname is model path relative to base_path
+ eg.
+ abs_path is /path/to/models/stable-diffusion/custom_group/model_name.ckpt
+ base_path is /path/to/models/stable-diffusion
+ fullname is custom_group/model_name.ckpt
+ basename is custom_group/model_name
+ extension is .ckpt
+ """
+
+ fullname = fullname.replace(os.path.sep, "/")
+ basename = os.path.splitext(fullname)[0]
+ extension = os.path.splitext(fullname)[1]
+ prefix_path = fullname.replace(os.path.basename(fullname), "")
+
+ abs_path = os.path.join(base_path, fullname)
+ file_stats = os.stat(abs_path)
+
+ # Resolve metadata
+ metadata = utils.get_model_metadata(abs_path)
+
+ # Resolve preview
+ image_name = utils.get_model_preview_name(abs_path)
+ image_name = os.path.join(prefix_path, image_name)
+ abs_image_path = os.path.join(base_path, image_name)
+ if os.path.isfile(abs_image_path):
+ image_state = os.stat(abs_image_path)
+ image_timestamp = round(image_state.st_mtime_ns / 1000000)
+ image_name = f"{image_name}?ts={image_timestamp}"
+ model_preview = (
+ f"/model-manager/preview/{model_type}/{path_index}/{image_name}"
+ )
+
+ # Resolve description
+ description_file = utils.get_model_description_name(abs_path)
+ description_file = os.path.join(prefix_path, description_file)
+ abs_desc_path = os.path.join(base_path, description_file)
+ description = None
+ if os.path.isfile(abs_desc_path):
+ with open(abs_desc_path, "r", encoding="utf-8") as f:
+ description = f.read()
+
+ out.append(
+ {
+ "fullname": fullname,
+ "basename": basename,
+ "extension": extension,
+ "type": model_type,
+ "pathIndex": path_index,
+ "sizeBytes": file_stats.st_size,
+ "preview": model_preview,
+ "description": description,
+ "createdAt": round(file_stats.st_ctime_ns / 1000000),
+ "updatedAt": round(file_stats.st_mtime_ns / 1000000),
+ "metadata": metadata,
+ }
+ )
+
+ return out
+
+
+def update_model(model_path: str, post: MultiDictProxy):
+
+ if "previewFile" in post:
+ previewFile = post["previewFile"]
+ utils.save_model_preview_image(model_path, previewFile)
+
+ if "description" in post:
+ description = post["description"]
+ utils.save_model_description(model_path, description)
+
+ if "type" in post and "pathIndex" in post and "fullname" in post:
+ model_type = post.get("type", None)
+ path_index = int(post.get("pathIndex", None))
+ fullname = post.get("fullname", None)
+ if model_type is None or path_index is None or fullname is None:
+ raise RuntimeError("Invalid type or pathIndex or fullname")
+
+ # get new path
+ new_model_path = utils.get_full_path(model_type, path_index, fullname)
+
+ utils.rename_model(model_path, new_model_path)
+
+
+def remove_model(model_path: str):
+ model_dirname = os.path.dirname(model_path)
+ os.remove(model_path)
+
+ model_previews = utils.get_model_all_images(model_path)
+ for preview in model_previews:
+ os.remove(os.path.join(model_dirname, preview))
+
+ model_descriptions = utils.get_model_all_descriptions(model_path)
+ for description in model_descriptions:
+ os.remove(os.path.join(model_dirname, description))
+
+
+async def create_model_download_task(post):
+ dict_post = dict(post)
+ return await download.create_model_download_task(dict_post)
diff --git a/py/socket.py b/py/socket.py
new file mode 100644
index 0000000..13a39f9
--- /dev/null
+++ b/py/socket.py
@@ -0,0 +1,63 @@
+import aiohttp
+import logging
+import uuid
+import json
+from aiohttp import web
+from typing import Any, Callable, Awaitable
+from . import utils
+
+
+__sockets: dict[str, web.WebSocketResponse] = {}
+
+
+async def create_websocket_handler(
+ request, handler: Callable[[str, Any, str], Awaitable[Any]]
+):
+ ws = web.WebSocketResponse()
+ await ws.prepare(request)
+ sid = request.rel_url.query.get("clientId", "")
+ if sid:
+ # Reusing existing session, remove old
+ __sockets.pop(sid, None)
+ else:
+ sid = uuid.uuid4().hex
+
+ __sockets[sid] = ws
+
+ try:
+ async for msg in ws:
+ if msg.type == aiohttp.WSMsgType.ERROR:
+ logging.warning(
+ "ws connection closed with exception %s" % ws.exception()
+ )
+ if msg.type == aiohttp.WSMsgType.TEXT:
+ data = json.loads(msg.data)
+ await handler(data.get("type"), data.get("detail"), sid)
+ finally:
+ __sockets.pop(sid, None)
+ return ws
+
+
+async def send_json(event: str, data: Any, sid: str = None):
+ detail = utils.unpack_dataclass(data)
+ message = {"type": event, "data": detail}
+
+ if sid is None:
+ socket_list = list(__sockets.values())
+ for ws in socket_list:
+ await __send_socket_catch_exception(ws.send_json, message)
+ elif sid in __sockets:
+ await __send_socket_catch_exception(__sockets[sid].send_json, message)
+
+
+async def __send_socket_catch_exception(function, message):
+ try:
+ await function(message)
+ except (
+ aiohttp.ClientError,
+ aiohttp.ClientPayloadError,
+ ConnectionResetError,
+ BrokenPipeError,
+ ConnectionError,
+ ) as err:
+ logging.warning("send error: {}".format(err))
diff --git a/py/thread.py b/py/thread.py
new file mode 100644
index 0000000..82689f7
--- /dev/null
+++ b/py/thread.py
@@ -0,0 +1,64 @@
+import asyncio
+import threading
+import queue
+import logging
+from . import utils
+
+
+class DownloadThreadPool:
+ def __init__(self) -> None:
+ self.workers_count = 0
+ self.task_queue = queue.Queue()
+ self.running_tasks = set()
+ self._lock = threading.Lock()
+
+ default_max_workers = 5
+ max_workers: int = utils.get_setting_value(
+ "download.max_task_count", default_max_workers
+ )
+
+ if max_workers <= 0:
+ max_workers = default_max_workers
+ utils.set_setting_value("download.max_task_count", max_workers)
+
+ self.max_worker = max_workers
+
+ def submit(self, task, task_id):
+ with self._lock:
+ if task_id in self.running_tasks:
+ return "Existing"
+ self.running_tasks.add(task_id)
+ self.task_queue.put((task, task_id))
+ return self._adjust_worker_count()
+
+ def _adjust_worker_count(self):
+ if self.workers_count < self.max_worker:
+ self._start_worker()
+ return "Running"
+ else:
+ return "Waiting"
+
+ def _start_worker(self):
+ t = threading.Thread(target=self._worker, daemon=True)
+ t.start()
+ with self._lock:
+ self.workers_count += 1
+
+ def _worker(self):
+ loop = asyncio.new_event_loop()
+
+ while True:
+ if self.task_queue.empty():
+ break
+
+ task, task_id = self.task_queue.get()
+
+ try:
+ loop.run_until_complete(task(task_id))
+ with self._lock:
+ self.running_tasks.remove(task_id)
+ except Exception as e:
+ logging.error(f"worker run error: {str(e)}")
+
+ with self._lock:
+ self.workers_count -= 1
diff --git a/py/utils.py b/py/utils.py
new file mode 100644
index 0000000..fcdac9f
--- /dev/null
+++ b/py/utils.py
@@ -0,0 +1,282 @@
+import os
+import comfy.utils
+import json
+import logging
+import folder_paths
+from aiohttp import web
+from typing import Any
+from . import config
+
+
+def resolve_model_base_paths():
+ folders = list(folder_paths.folder_names_and_paths.keys())
+ config.model_base_paths = {}
+ for folder in folders:
+ if folder == "configs":
+ continue
+ if folder == "custom_nodes":
+ continue
+ config.model_base_paths[folder] = folder_paths.get_folder_paths(folder)
+
+
+def get_full_path(model_type: str, path_index: int, filename: str):
+ """
+ Get the absolute path in the model type through string concatenation.
+ """
+ folders = config.model_base_paths.get(model_type, [])
+ if not path_index < len(folders):
+ raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
+ base_path = folders[path_index]
+ return os.path.join(base_path, filename)
+
+
+def get_valid_full_path(model_type: str, path_index: int, filename: str):
+ """
+ Like get_full_path but it will check whether the file is valid.
+ """
+ folders = config.model_base_paths.get(model_type, [])
+ if not path_index < len(folders):
+ raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
+ base_path = folders[path_index]
+ full_path = os.path.join(base_path, filename)
+ if os.path.isfile(full_path):
+ return full_path
+ elif os.path.islink(full_path):
+ raise RuntimeError(
+ f"WARNING path {full_path} exists but doesn't link anywhere, skipping."
+ )
+
+
+def get_download_path():
+ download_path = os.path.join(config.extension_uri, "downloads")
+ if not os.path.exists(download_path):
+ os.makedirs(download_path)
+ return download_path
+
+
+def recursive_search_files(directory: str):
+ files, folder_all = folder_paths.recursive_search(
+ directory, excluded_dir_names=[".git"]
+ )
+ files.sort()
+ return files
+
+
+def search_files(directory: str):
+ entries = os.listdir(directory)
+ files = [f for f in entries if os.path.isfile(os.path.join(directory, f))]
+ files.sort()
+ return files
+
+
+def get_model_metadata(filename: str):
+ if not filename.endswith(".safetensors"):
+ return {}
+ try:
+ out = comfy.utils.safetensors_header(filename, max_size=1024 * 1024)
+ if out is None:
+ return {}
+ dt = json.loads(out)
+ if not "__metadata__" in dt:
+ return {}
+ return dt["__metadata__"]
+ except:
+ return {}
+
+
+def get_model_all_images(model_path: str):
+ base_dirname = os.path.dirname(model_path)
+ files = search_files(base_dirname)
+ files = folder_paths.filter_files_content_types(files, ["image"])
+
+ basename = os.path.splitext(os.path.basename(model_path))[0]
+ output: list[str] = []
+ for file in files:
+ file_basename = os.path.splitext(file)[0]
+ if file_basename == basename:
+ output.append(file)
+ if file_basename == f"{basename}.preview":
+ output.append(file)
+ return output
+
+
+def get_model_preview_name(model_path: str):
+ images = get_model_all_images(model_path)
+ return images[0] if len(images) > 0 else "no-preview.png"
+
+
+def save_model_preview_image(model_path: str, image_file: Any):
+ if not isinstance(image_file, web.FileField):
+ raise RuntimeError("Invalid image file")
+
+ content_type: str = image_file.content_type
+ if not content_type.startswith("image/"):
+ raise RuntimeError(f"FileTypeError: expected image, got {content_type}")
+
+ base_dirname = os.path.dirname(model_path)
+
+ # remove old preview images
+ old_preview_images = get_model_all_images(model_path)
+ a1111_civitai_helper_image = False
+ for image in old_preview_images:
+ if os.path.splitext(image)[1].endswith(".preview"):
+ a1111_civitai_helper_image = True
+ image_path = os.path.join(base_dirname, image)
+ os.remove(image_path)
+
+ # save new preview image
+ basename = os.path.splitext(os.path.basename(model_path))[0]
+ extension = f".{content_type.split('/')[1]}"
+ new_preview_path = os.path.join(base_dirname, f"{basename}{extension}")
+
+ with open(new_preview_path, "wb") as f:
+ f.write(image_file.file.read())
+
+ # TODO Is it possible to abandon the current rules and adopt the rules of a1111 civitai_helper?
+ if a1111_civitai_helper_image:
+ """
+ Keep preview image of a1111_civitai_helper
+ """
+ new_preview_path = os.path.join(base_dirname, f"{basename}.preview{extension}")
+ with open(new_preview_path, "wb") as f:
+ f.write(image_file.file.read())
+
+
+def get_model_all_descriptions(model_path: str):
+ base_dirname = os.path.dirname(model_path)
+ files = search_files(base_dirname)
+ files = folder_paths.filter_files_extensions(files, [".txt", ".md"])
+
+ basename = os.path.splitext(os.path.basename(model_path))[0]
+ output: list[str] = []
+ for file in files:
+ file_basename = os.path.splitext(file)[0]
+ if file_basename == basename:
+ output.append(file)
+ return output
+
+
+def get_model_description_name(model_path: str):
+ descriptions = get_model_all_descriptions(model_path)
+ basename = os.path.splitext(os.path.basename(model_path))[0]
+ return descriptions[0] if len(descriptions) > 0 else f"{basename}.md"
+
+
+def save_model_description(model_path: str, content: Any):
+ if not isinstance(content, str):
+ raise RuntimeError("Invalid description")
+
+ base_dirname = os.path.dirname(model_path)
+
+ # remove old descriptions
+ old_descriptions = get_model_all_descriptions(model_path)
+ for desc in old_descriptions:
+ description_path = os.path.join(base_dirname, desc)
+ os.remove(description_path)
+
+ # save new description
+ basename = os.path.splitext(os.path.basename(model_path))[0]
+ extension = ".md"
+ new_desc_path = os.path.join(base_dirname, f"{basename}{extension}")
+
+ with open(new_desc_path, "w", encoding="utf-8") as f:
+ f.write(content)
+
+
+def rename_model(model_path: str, new_model_path: str):
+ if model_path == new_model_path:
+ return
+
+ if os.path.exists(new_model_path):
+ raise RuntimeError(f"Model {new_model_path} already exists")
+
+ model_name = os.path.splitext(os.path.basename(model_path))[0]
+ new_model_name = os.path.splitext(os.path.basename(new_model_path))[0]
+
+ model_dirname = os.path.dirname(model_path)
+ new_model_dirname = os.path.dirname(new_model_path)
+
+ if not os.path.exists(new_model_dirname):
+ os.makedirs(new_model_dirname)
+
+ # move model
+ os.rename(model_path, new_model_path)
+
+ # move preview
+ previews = get_model_all_images(model_path)
+ for preview in previews:
+ preview_path = os.path.join(model_dirname, preview)
+ preview_name = os.path.splitext(preview)[0]
+ preview_ext = os.path.splitext(preview)[1]
+ new_preview_path = (
+ os.path.join(new_model_dirname, new_model_name + preview_ext)
+ if preview_name == model_name
+ else os.path.join(
+ new_model_dirname, new_model_name + ".preview" + preview_ext
+ )
+ )
+ os.rename(preview_path, new_preview_path)
+
+ # move description
+ description = get_model_description_name(model_path)
+ description_path = os.path.join(model_dirname, description)
+ if os.path.isfile(description_path):
+ new_description_path = os.path.join(new_model_dirname, f"{new_model_name}.md")
+ os.rename(description_path, new_description_path)
+
+
+import pickle
+
+
+def save_dict_pickle_file(filename: str, data: dict):
+ with open(filename, "wb") as f:
+ pickle.dump(data, f)
+
+
+def load_dict_pickle_file(filename: str) -> dict:
+ with open(filename, "rb") as f:
+ data = pickle.load(f)
+ return data
+
+
+def resolve_setting_key(key: str) -> str:
+ key_paths = key.split(".")
+ setting_id = config.setting_key
+ try:
+ for key_path in key_paths:
+ setting_id = setting_id[key_path]
+ except:
+ pass
+ if not isinstance(setting_id, str):
+ raise RuntimeError(f"Invalid key: {key}")
+
+ return setting_id
+
+
+def set_setting_value(key: str, value: Any):
+ setting_id = resolve_setting_key(key)
+ fake_request = config.FakeRequest()
+ settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
+ settings[setting_id] = value
+ config.serverInstance.user_manager.settings.save_settings(fake_request, settings)
+
+
+def get_setting_value(key: str, default: Any = None) -> Any:
+ setting_id = resolve_setting_key(key)
+ fake_request = config.FakeRequest()
+ settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
+ return settings.get(setting_id, default)
+
+
+from dataclasses import asdict, is_dataclass
+
+
+def unpack_dataclass(data: Any):
+ if isinstance(data, dict):
+ return {key: unpack_dataclass(value) for key, value in data.items()}
+ elif isinstance(data, list):
+ return [unpack_dataclass(x) for x in data]
+ elif is_dataclass(data):
+ return asdict(data)
+ else:
+ return data
diff --git a/src/App.vue b/src/App.vue
new file mode 100644
index 0000000..5eb7ae4
--- /dev/null
+++ b/src/App.vue
@@ -0,0 +1,41 @@
+
+ | + {{ $t(`info.${item.key}`) }} + | +{{ item.display }} | +
| + {{ item.key }} + | +{{ item.value }} | +
+ {{ $t('uploadFile') }} +
+