From 724a9425c4f1ce3a487ee9336adb54a1b0834838 Mon Sep 17 00:00:00 2001 From: Christian Bastian Date: Thu, 4 Jan 2024 07:13:28 -0500 Subject: [PATCH] Triage `folder_paths` crashing on unregistered dir. --- __init__.py | 60 ++++++++++++++++++++++++++++---------------- web/model-manager.js | 34 ++++++++++++------------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/__init__.py b/__init__.py index 6715c57..35bcd6a 100644 --- a/__init__.py +++ b/__init__.py @@ -11,10 +11,6 @@ import folder_paths requests.packages.urllib3.disable_warnings() -def folder_paths_get_supported_pt_extensions(folder_name): # Missing api function. - return folder_paths.folder_names_and_paths[folder_name][1] - - comfyui_model_uri = os.path.join(os.getcwd(), "models") extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager") index_uri = os.path.join(extension_uri, "index.json") @@ -26,6 +22,26 @@ image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp") #hash_buffer_size = 4096 + +def folder_paths_get_folder_paths(folder_name): # API function crashes querying unknown model folder + paths = folder_paths.folder_names_and_paths + if folder_name in paths: + return paths[folder_name][0][:] + + maybe_path = os.path.join(comfyui_model_uri, folder_name) + if os.path.exists(maybe_path): + return [maybe_path] + return [] + + +def folder_paths_get_supported_pt_extensions(folder_name): # Missing API function + paths = folder_paths.folder_names_and_paths + if folder_name in paths: + return paths[folder_name][1] + + return set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) + + def get_safetensor_header(path): try: with open(path, "rb") as f: @@ -43,21 +59,21 @@ def end_swap_and_pop(x, i): def model_type_to_dir_name(model_type): - # TODO: Figure out how to remove this. - match model_type: - case "checkpoint": - return "checkpoints" - case "diffuser": - return "diffusers" - case "embedding": - return "embeddings" - case "hypernetwork": - return "hypernetworks" - case "lora": - return "loras" - case "upscale_model": - return "upscale_models" - return model_type + if model_type == "checkpoint": return "checkpoints" + #elif model_type == "clip": return "clip" + #elif model_type == "clip_vision": return "clip_vision" + #elif model_type == "controlnet": return "controlnet" + elif model_type == "diffuser": return "diffusers" + elif model_type == "embedding": return "embeddings" + #elif model_type== "gligen": return "gligen" + elif model_type == "hypernetwork": return "hypernetworks" + elif model_type == "lora": return "loras" + #elif model_type == "style_models": return "style_models" + #elif model_type == "unet": return "unet" + elif model_type == "upscale_model": return "upscale_models" + #elif model_type == "vae": return "vae" + #elif model_type == "vae_approx": return "vae_approx" + else: return model_type @server.PromptServer.instance.routes.get("/model-manager/image-preview") @@ -77,7 +93,7 @@ async def img_preview(request): if j == -1: j = len(rel_image_path) base_index = int(uri[i + len(os.path.sep):j]) - base_path = folder_paths.get_folder_paths(model_type)[base_index] + base_path = folder_paths_get_folder_paths(model_type)[base_index] abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join if os.path.exists(abs_image_path): @@ -133,7 +149,7 @@ async def load_source_from(request): #print(checksum_cache) for model_type in model_types: - for model_base_path in folder_paths.get_folder_paths(model_type): + for model_base_path in folder_paths_get_folder_paths(model_type): if not os.path.exists(model_base_path): # Bug in main code? continue for cwd, _subdirs, files in os.walk(model_base_path): @@ -182,7 +198,7 @@ async def load_download_models(request): for model_type in model_types: model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type)) file_names = [] - for base_path_index, model_base_path in enumerate(folder_paths.get_folder_paths(model_type)): + for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)): if not os.path.exists(model_base_path): # Bug in main code? continue for cwd, _subdirs, files in os.walk(model_base_path): diff --git a/web/model-manager.js b/web/model-manager.js index 9f5bc2f..acc965c 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -22,20 +22,20 @@ function request(url, options) { } function modelNodeType(modelType) { - if (modelType === "checkpoints") return "CheckpointLoaderSimple"; - else if (modelType === "clip") return "CLIPLoader"; - else if (modelType === "clip_vision") return "CLIPVisionLoader"; - else if (modelType === "controlnet") return "ControlNetLoader"; - else if (modelType === "diffusers") return "DiffusersLoader"; - else if (modelType === "embeddings") return "Embedding"; - else if (modelType === "gligen") return "GLIGENLoader"; - else if (modelType === "hypernetworks") return "HypernetworkLoader"; - else if (modelType === "loras") return "LoraLoader"; - else if (modelType === "style_models") return "StyleModelLoader"; - else if (modelType === "unet") return "UNETLoader"; - else if (modelType === "upscale_models") return "UpscaleModelLoader"; - else if (modelType === "vae") return "VAELoader"; - else if (modelType === "vae_approx") return undefined; + if (modelType === "checkpoints") { return "CheckpointLoaderSimple"; } + else if (modelType === "clip") { return "CLIPLoader"; } + else if (modelType === "clip_vision") { return "CLIPVisionLoader"; } + else if (modelType === "controlnet") { return "ControlNetLoader"; } + else if (modelType === "diffusers") { return "DiffusersLoader"; } + else if (modelType === "embeddings") { return "Embedding"; } + else if (modelType === "gligen") { return "GLIGENLoader"; } + else if (modelType === "hypernetworks") { return "HypernetworkLoader"; } + else if (modelType === "loras") { return "LoraLoader"; } + else if (modelType === "style_models") { return "StyleModelLoader"; } + else if (modelType === "unet") { return "UNETLoader"; } + else if (modelType === "upscale_models") { return "UpscaleModelLoader"; } + else if (modelType === "vae") { return "VAELoader"; } + else if (modelType === "vae_approx") { return undefined; } else { console.warn(`ModelType ${modelType} unrecognized.`); return undefined; } } @@ -48,8 +48,8 @@ function pathToFileString(path) { return path.slice(i); } -function insertEmbeddingIntoText(currentText, embeddingFile, removeExtension = false) { - if (removeExtension) { +function insertEmbeddingIntoText(currentText, embeddingFile, extensionRegex = null) { + if (extensionRegex) { // TODO: setting.remove_extension_embedding } // TODO: don't add if it is already in the text? @@ -712,7 +712,7 @@ class ModelManager extends ComfyDialog { name: "model-type", onchange: () => this.#modelGridUpdate(), }, - [ + [ // TODO: generate based on existing model folders $el("option", ["checkpoints"]), $el("option", ["clip"]), $el("option", ["clip_vision"]),