diff --git a/__init__.py b/__init__.py index c3f8154..e93d1e9 100644 --- a/__init__.py +++ b/__init__.py @@ -20,21 +20,51 @@ requests.packages.urllib3.disable_warnings() import folder_paths -config_loader_path = os.path.join(os.path.dirname(__file__), 'config_loader.py') +comfyui_model_uri = folder_paths.models_dir + +extension_uri = os.path.dirname(__file__) + +config_loader_path = os.path.join(extension_uri, 'config_loader.py') config_loader_spec = importlib.util.spec_from_file_location('config_loader', config_loader_path) config_loader = importlib.util.module_from_spec(config_loader_spec) config_loader_spec.loader.exec_module(config_loader) -comfyui_model_uri = os.path.join(os.getcwd(), "models") -extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager") no_preview_image = os.path.join(extension_uri, "no-preview.png") ui_settings_uri = os.path.join(extension_uri, "ui_settings.yaml") server_settings_uri = os.path.join(extension_uri, "server_settings.yaml") fallback_model_extensions = set([".bin", ".ckpt", ".onnx", ".pt", ".pth", ".safetensors"]) # TODO: magic values -image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp") # TODO: JavaScript does not know about this (x2 states) +image_extensions = ( + ".png", # order matters + ".webp", + ".jpeg", + ".jpg", + ".gif", + ".apng", +) +stable_diffusion_webui_civitai_helper_image_extensions = ( + ".preview.png", # order matters + ".preview.webp", + ".preview.jpeg", + ".preview.jpg", + ".preview.gif", + ".preview.apng", +) +preview_extensions = ( # TODO: JavaScript does not know about this (x2 states) + image_extensions + # order matters + stable_diffusion_webui_civitai_helper_image_extensions +) +model_info_extension = ".txt" #video_extensions = (".avi", ".mp4", ".webm") # TODO: Requires ffmpeg or cv2. Cache preview frame? +def split_valid_ext(s, *arg_exts): + sl = s.lower() + for exts in arg_exts: + for ext in exts: + if sl.endswith(ext.lower()): + return (s[:-len(ext)], ext) + return (s, "") + _folder_names_and_paths = None # dict[str, tuple[list[str], list[str]]] def folder_paths_folder_names_and_paths(refresh = False): global _folder_names_and_paths @@ -194,18 +224,18 @@ async def get_model_preview(request): uri = request.query.get("uri") image_path = no_preview_image - image_extension = "png" + image_type = "png" image_data = None if uri != "no-preview": sep = os.path.sep uri = uri.replace("/" if sep == "\\" else "/", sep) path, _ = search_path_to_system_path(uri) - head, extension = os.path.splitext(path) + head, extension = split_valid_ext(path, preview_extensions) if os.path.exists(path): - image_extension = extension[1:] + image_type = extension.rsplit(".", 1)[1] image_path = path - elif os.path.exists(head) and os.path.splitext(head)[1] == ".safetensors": - image_extension = extension[1:] + elif os.path.exists(head) and head.endswith(".safetensors"): + image_type = extension.rsplit(".", 1)[1] header = get_safetensor_header(head) metadata = header.get("__metadata__", None) if metadata is not None: @@ -218,23 +248,32 @@ async def get_model_preview(request): with open(image_path, "rb") as file: image_data = file.read() - return web.Response(body=image_data, content_type="image/" + image_extension) + return web.Response(body=image_data, content_type="image/" + image_type) + + +@server.PromptServer.instance.routes.get("/model-manager/image/extensions") +async def get_image_extensions(request): + return web.json_response(image_extensions) def download_model_preview(formdata): path = formdata.get("path", None) if type(path) is not str: raise ("Invalid path!") - path, _ = search_path_to_system_path(path) - path_without_extension, _ = os.path.splitext(path) + path, model_type = search_path_to_system_path(path) + model_type_extensions = folder_paths_get_supported_pt_extensions(model_type) + path_without_extension, _ = split_valid_ext(path, model_type_extensions) overwrite = formdata.get("overwrite", "true").lower() overwrite = True if overwrite == "true" else False image = formdata.get("image", None) if type(image) is str: - image_path = download_image(image, path, overwrite) - _, image_extension = os.path.splitext(image_path) + _, image_extension = split_valid_ext(image, image_extensions) # TODO: doesn't work for https://civitai.com/images/... + if image_extension == "": + raise ValueError("Invalid image type!") + image_path = path_without_extension + image_extension + download_file(image, image_path, overwrite) else: content_type = image.content_type if not content_type.startswith("image/"): @@ -251,7 +290,7 @@ def download_model_preview(formdata): with open(image_path, "wb") as f: f.write(image_data) - delete_same_name_files(path_without_extension, image_extensions, image_extension) + delete_same_name_files(path_without_extension, preview_extensions, image_extension) @server.PromptServer.instance.routes.post("/model-manager/preview/set") @@ -272,10 +311,11 @@ async def delete_model_preview(request): return web.json_response({ "success": False }) model_path = urllib.parse.unquote(model_path) - file, _ = search_path_to_system_path(model_path) - path_and_name, _ = os.path.splitext(file) - delete_same_name_files(path_and_name, image_extensions) - + model_path, model_type = search_path_to_system_path(model_path) + model_extensions = folder_paths_get_supported_pt_extensions(model_type) + path_and_name, _ = split_valid_ext(model_path, model_extensions) + delete_same_name_files(path_and_name, preview_extensions) + return web.json_response({ "success": True }) @@ -297,26 +337,32 @@ async def get_model_list(request): for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)): if not os.path.exists(model_base_path): # TODO: Bug in main code? ("ComfyUI\output\checkpoints", "ComfyUI\output\clip", "ComfyUI\models\t2i_adapter", "ComfyUI\output\vae") continue - for cwd, _subdirs, files in os.walk(model_base_path): + for cwd, subdirs, files in os.walk(model_base_path): dir_models = [] dir_images = [] for file in files: if file.lower().endswith(model_extensions): dir_models.append(file) - elif file.lower().endswith(image_extensions): + elif file.lower().endswith(preview_extensions): dir_images.append(file) for model in dir_models: - model_name, model_ext = os.path.splitext(model) + model_name, model_ext = split_valid_ext(model, model_extensions) image = None image_modified = None - for iImage in range(len(dir_images)-1, -1, -1): - image_name, _ = os.path.splitext(dir_images[iImage]) - if model_name == image_name: - image = end_swap_and_pop(dir_images, iImage) - img_abs_path = os.path.join(cwd, image) - image_modified = pathlib.Path(img_abs_path).stat().st_mtime_ns + for ext in preview_extensions: # order matters + for iImage in range(len(dir_images)-1, -1, -1): + image_name = dir_images[iImage] + if not image_name.lower().endswith(ext.lower()): + continue + image_name = image_name[:-len(ext)] + if model_name == image_name: + image = end_swap_and_pop(dir_images, iImage) + img_abs_path = os.path.join(cwd, image) + image_modified = pathlib.Path(img_abs_path).stat().st_mtime_ns + break + if image is not None: break abs_path = os.path.join(cwd, model) stats = pathlib.Path(abs_path).stat() @@ -378,7 +424,7 @@ def linear_directory_hierarchy(refresh = False): for dir_path_index, dir_path in enumerate(model_dirs): if not os.path.exists(dir_path) or os.path.isfile(dir_path): continue - + #dir_list.append({ "name": str(dir_path_index), "childIndex": None, "childCount": 0 }) dir_stack = [(dir_path, model_dir_child_index + dir_path_index)] while len(dir_stack) > 0: # DEPTH-FIRST @@ -403,8 +449,7 @@ def linear_directory_hierarchy(refresh = False): dir_child_count += 1 else: # file - _, file_extension = os.path.splitext(item_name) - if extension_whitelist is None or file_extension in extension_whitelist: + if extension_whitelist is None or split_valid_ext(item_name, extension_whitelist)[1] != "": dir_list.append({ "name": item_name }) dir_child_count += 1 if dir_child_count > 0: @@ -514,16 +559,6 @@ def download_file(url, filename, overwrite): os.rename(filename_temp, filename) -def download_image(image_uri, model_path, overwrite): - _, extension = os.path.splitext(image_uri) # TODO: doesn't work for https://civitai.com/images/... - if not extension in image_extensions: - raise ValueError("Invalid image type!") - path_without_extension, _ = os.path.splitext(model_path) - file = path_without_extension + extension - download_file(image_uri, file, overwrite) - return file - - @server.PromptServer.instance.routes.get("/model-manager/model/info") async def get_model_info(request): model_path = request.query.get("path", None) @@ -531,35 +566,36 @@ async def get_model_info(request): return web.json_response({ "success": False }) model_path = urllib.parse.unquote(model_path) - file, _ = search_path_to_system_path(model_path) - if file is None: + abs_path, model_type = search_path_to_system_path(model_path) + if abs_path is None: return web.json_response({}) info = {} - path, name = os.path.split(model_path) + comfyui_directory, name = os.path.split(model_path) info["File Name"] = name - info["File Directory"] = path - info["File Size"] = str(os.path.getsize(file)) + " bytes" - stats = pathlib.Path(file).stat() + info["File Directory"] = comfyui_directory + info["File Size"] = str(os.path.getsize(abs_path)) + " bytes" + stats = pathlib.Path(abs_path).stat() date_format = "%Y-%m-%d %H:%M:%S" date_modified = datetime.fromtimestamp(stats.st_mtime).strftime(date_format) info["Date Modified"] = date_modified info["Date Created"] = datetime.fromtimestamp(stats.st_ctime).strftime(date_format) - file_name, _ = os.path.splitext(file) + model_extensions = folder_paths_get_supported_pt_extensions(model_type) + abs_name , _ = split_valid_ext(abs_path, model_extensions) - for extension in image_extensions: - maybe_image = file_name + extension - if os.path.isfile(maybe_image): - image_path, _ = os.path.splitext(model_path) - image_modified = pathlib.Path(maybe_image).stat().st_mtime_ns + for extension in preview_extensions: + maybe_preview = abs_name + extension + if os.path.isfile(maybe_preview): + preview_path, _ = split_valid_ext(model_path, model_extensions) + preview_modified = pathlib.Path(maybe_preview).stat().st_mtime_ns info["Preview"] = { - "path": urllib.parse.quote_plus(image_path + extension), - "dateModified": urllib.parse.quote_plus(str(image_modified)), + "path": urllib.parse.quote_plus(preview_path + extension), + "dateModified": urllib.parse.quote_plus(str(preview_modified)), } break - header = get_safetensor_header(file) + header = get_safetensor_header(abs_path) metadata = header.get("__metadata__", None) #json.dump(metadata, sys.stdout, indent=4) #print() @@ -622,7 +658,7 @@ async def get_model_info(request): training_comment if training_comment != "None" else "" ).strip() - txt_file = file_name + ".txt" + txt_file = abs_name + model_info_extension notes = "" if os.path.isfile(txt_file): with open(txt_file, 'r', encoding="utf-8") as f: @@ -687,8 +723,9 @@ async def download_model(request): return web.json_response(result) name = formdata.get("name") - _, model_extension = os.path.splitext(name) - if not model_extension in folder_paths_get_supported_pt_extensions(model_type): + model_extensions = folder_paths_get_supported_pt_extensions(model_type) + _, model_extension = split_valid_ext(name, model_extensions) + if model_extension == "": result["invalid"] = "name" return web.json_response(result) file_name = os.path.join(directory, name) @@ -725,8 +762,9 @@ async def move_model(request): old_file, old_model_type = search_path_to_system_path(old_file) if not os.path.isfile(old_file): return web.json_response({ "success": False }) - _, model_extension = os.path.splitext(old_file) - if not model_extension in folder_paths_get_supported_pt_extensions(old_model_type): + old_model_extensions = folder_paths_get_supported_pt_extensions(old_model_type) + old_file_without_extension, model_extension = split_valid_ext(old_file, old_model_extensions) + if model_extension == "": # cannot move arbitrary files return web.json_response({ "success": False }) @@ -740,7 +778,10 @@ async def move_model(request): if os.path.isfile(new_file): # cannot overwrite existing file return web.json_response({ "success": False }) - if not model_extension in folder_paths_get_supported_pt_extensions(new_model_type): + new_model_extensions = folder_paths_get_supported_pt_extensions(new_model_type) + new_file_without_extension, new_model_extension = split_valid_ext(new_file, new_model_extensions) + if model_extension != new_model_extension: + # cannot change extension return web.json_response({ "success": False }) new_file_dir, _ = os.path.split(new_file) if not os.path.isdir(new_file_dir): @@ -754,11 +795,8 @@ async def move_model(request): print(e, file=sys.stderr, flush=True) return web.json_response({ "success": False }) - old_file_without_extension, _ = os.path.splitext(old_file) - new_file_without_extension, _ = os.path.splitext(new_file) - - # TODO: this could overwrite existing files... - for extension in image_extensions + (".txt",): + # TODO: this could overwrite existing files in destination... + for extension in preview_extensions + (model_info_extension,): old_file = old_file_without_extension + extension if os.path.isfile(old_file): try: @@ -772,9 +810,9 @@ async def move_model(request): def delete_same_name_files(path_without_extension, extensions, keep_extension=None): for extension in extensions: if extension == keep_extension: continue - image_file = path_without_extension + extension - if os.path.isfile(image_file): - os.remove(image_file) + file = path_without_extension + extension + if os.path.isfile(file): + os.remove(file) @server.PromptServer.instance.routes.post("/model-manager/model/delete") @@ -785,27 +823,22 @@ async def delete_model(request): if model_path is None: return web.json_response(result) model_path = urllib.parse.unquote(model_path) - - file, model_type = search_path_to_system_path(model_path) - if file is None: + model_path, model_type = search_path_to_system_path(model_path) + if model_path is None: return web.json_response(result) - _, extension = os.path.splitext(file) - if not extension in folder_paths_get_supported_pt_extensions(model_type): + model_extensions = folder_paths_get_supported_pt_extensions(model_type) + path_and_name, model_extension = split_valid_ext(model_path, model_extensions) + if model_extension == "": # cannot delete arbitrary files return web.json_response(result) - if os.path.isfile(file): - os.remove(file) + if os.path.isfile(model_path): + os.remove(model_path) result["success"] = True - path_and_name, _ = os.path.splitext(file) - - delete_same_name_files(path_and_name, image_extensions) - - txt_file = path_and_name + ".txt" - if os.path.isfile(txt_file): - os.remove(txt_file) + delete_same_name_files(path_and_name, preview_extensions) + delete_same_name_files(path_and_name, (model_info_extension,)) return web.json_response(result) @@ -821,9 +854,10 @@ async def set_notes(request): model_path = body.get("path", None) if type(model_path) is not str: return web.json_response({ "success": False }) - model_path, _ = search_path_to_system_path(model_path) - file_path_without_extension, _ = os.path.splitext(model_path) - filename = os.path.normpath(file_path_without_extension + ".txt") + model_path, model_type = search_path_to_system_path(model_path) + model_extensions = folder_paths_get_supported_pt_extensions(model_type) + file_path_without_extension, _ = split_valid_ext(model_path, model_extensions) + filename = os.path.normpath(file_path_without_extension + model_info_extension) if text.isspace() or text == "": if os.path.exists(filename): os.remove(filename) diff --git a/web/model-manager.js b/web/model-manager.js index 397a17d..8ddfcff 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -35,7 +35,21 @@ const modelNodeType = { }; const MODEL_EXTENSIONS = [".bin", ".ckpt", ".onnx", ".pt", ".pth", ".safetensors"]; // TODO: ask server for? -const IMAGE_EXTENSIONS = [".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp"]; // TODO: ask server for? +const IMAGE_EXTENSIONS = [ + ".png", + ".webp", + ".jpeg", + ".jpg", + ".gif", + ".apng", + + ".preview.png", + ".preview.webp", + ".preview.jpeg", + ".preview.jpg", + ".preview.gif", + ".preview.apng", +]; // TODO: /model-manager/image/extensions class SearchPath { /** @@ -1413,7 +1427,7 @@ class ModelInfoView { $: (el) => (this.elements.setPreviewButton = el), textContent: "Set as Preview", onclick: async(e) => { - const confirmation = window.confirm("Change preview image PERMANENTLY?"); + const confirmation = window.confirm("Change preview image(s) PERMANENTLY?"); let updatedPreview = false; if (confirmation) { e.target.disabled = true;