Triage folder_paths crashing on unregistered dir.

This commit is contained in:
Christian Bastian
2024-01-04 07:13:28 -05:00
parent 5517ae68e4
commit 724a9425c4
2 changed files with 55 additions and 39 deletions

View File

@@ -11,10 +11,6 @@ import folder_paths
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()
def folder_paths_get_supported_pt_extensions(folder_name): # Missing api function.
return folder_paths.folder_names_and_paths[folder_name][1]
comfyui_model_uri = os.path.join(os.getcwd(), "models") comfyui_model_uri = os.path.join(os.getcwd(), "models")
extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager") extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager")
index_uri = os.path.join(extension_uri, "index.json") index_uri = os.path.join(extension_uri, "index.json")
@@ -26,6 +22,26 @@ image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp")
#hash_buffer_size = 4096 #hash_buffer_size = 4096
def folder_paths_get_folder_paths(folder_name): # API function crashes querying unknown model folder
paths = folder_paths.folder_names_and_paths
if folder_name in paths:
return paths[folder_name][0][:]
maybe_path = os.path.join(comfyui_model_uri, folder_name)
if os.path.exists(maybe_path):
return [maybe_path]
return []
def folder_paths_get_supported_pt_extensions(folder_name): # Missing API function
paths = folder_paths.folder_names_and_paths
if folder_name in paths:
return paths[folder_name][1]
return set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
def get_safetensor_header(path): def get_safetensor_header(path):
try: try:
with open(path, "rb") as f: with open(path, "rb") as f:
@@ -43,21 +59,21 @@ def end_swap_and_pop(x, i):
def model_type_to_dir_name(model_type): def model_type_to_dir_name(model_type):
# TODO: Figure out how to remove this. if model_type == "checkpoint": return "checkpoints"
match model_type: #elif model_type == "clip": return "clip"
case "checkpoint": #elif model_type == "clip_vision": return "clip_vision"
return "checkpoints" #elif model_type == "controlnet": return "controlnet"
case "diffuser": elif model_type == "diffuser": return "diffusers"
return "diffusers" elif model_type == "embedding": return "embeddings"
case "embedding": #elif model_type== "gligen": return "gligen"
return "embeddings" elif model_type == "hypernetwork": return "hypernetworks"
case "hypernetwork": elif model_type == "lora": return "loras"
return "hypernetworks" #elif model_type == "style_models": return "style_models"
case "lora": #elif model_type == "unet": return "unet"
return "loras" elif model_type == "upscale_model": return "upscale_models"
case "upscale_model": #elif model_type == "vae": return "vae"
return "upscale_models" #elif model_type == "vae_approx": return "vae_approx"
return model_type else: return model_type
@server.PromptServer.instance.routes.get("/model-manager/image-preview") @server.PromptServer.instance.routes.get("/model-manager/image-preview")
@@ -77,7 +93,7 @@ async def img_preview(request):
if j == -1: if j == -1:
j = len(rel_image_path) j = len(rel_image_path)
base_index = int(uri[i + len(os.path.sep):j]) base_index = int(uri[i + len(os.path.sep):j])
base_path = folder_paths.get_folder_paths(model_type)[base_index] base_path = folder_paths_get_folder_paths(model_type)[base_index]
abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join
if os.path.exists(abs_image_path): if os.path.exists(abs_image_path):
@@ -133,7 +149,7 @@ async def load_source_from(request):
#print(checksum_cache) #print(checksum_cache)
for model_type in model_types: for model_type in model_types:
for model_base_path in folder_paths.get_folder_paths(model_type): for model_base_path in folder_paths_get_folder_paths(model_type):
if not os.path.exists(model_base_path): # Bug in main code? if not os.path.exists(model_base_path): # Bug in main code?
continue continue
for cwd, _subdirs, files in os.walk(model_base_path): for cwd, _subdirs, files in os.walk(model_base_path):
@@ -182,7 +198,7 @@ async def load_download_models(request):
for model_type in model_types: for model_type in model_types:
model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type)) model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type))
file_names = [] file_names = []
for base_path_index, model_base_path in enumerate(folder_paths.get_folder_paths(model_type)): for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)):
if not os.path.exists(model_base_path): # Bug in main code? if not os.path.exists(model_base_path): # Bug in main code?
continue continue
for cwd, _subdirs, files in os.walk(model_base_path): for cwd, _subdirs, files in os.walk(model_base_path):

View File

@@ -22,20 +22,20 @@ function request(url, options) {
} }
function modelNodeType(modelType) { function modelNodeType(modelType) {
if (modelType === "checkpoints") return "CheckpointLoaderSimple"; if (modelType === "checkpoints") { return "CheckpointLoaderSimple"; }
else if (modelType === "clip") return "CLIPLoader"; else if (modelType === "clip") { return "CLIPLoader"; }
else if (modelType === "clip_vision") return "CLIPVisionLoader"; else if (modelType === "clip_vision") { return "CLIPVisionLoader"; }
else if (modelType === "controlnet") return "ControlNetLoader"; else if (modelType === "controlnet") { return "ControlNetLoader"; }
else if (modelType === "diffusers") return "DiffusersLoader"; else if (modelType === "diffusers") { return "DiffusersLoader"; }
else if (modelType === "embeddings") return "Embedding"; else if (modelType === "embeddings") { return "Embedding"; }
else if (modelType === "gligen") return "GLIGENLoader"; else if (modelType === "gligen") { return "GLIGENLoader"; }
else if (modelType === "hypernetworks") return "HypernetworkLoader"; else if (modelType === "hypernetworks") { return "HypernetworkLoader"; }
else if (modelType === "loras") return "LoraLoader"; else if (modelType === "loras") { return "LoraLoader"; }
else if (modelType === "style_models") return "StyleModelLoader"; else if (modelType === "style_models") { return "StyleModelLoader"; }
else if (modelType === "unet") return "UNETLoader"; else if (modelType === "unet") { return "UNETLoader"; }
else if (modelType === "upscale_models") return "UpscaleModelLoader"; else if (modelType === "upscale_models") { return "UpscaleModelLoader"; }
else if (modelType === "vae") return "VAELoader"; else if (modelType === "vae") { return "VAELoader"; }
else if (modelType === "vae_approx") return undefined; else if (modelType === "vae_approx") { return undefined; }
else { console.warn(`ModelType ${modelType} unrecognized.`); return undefined; } else { console.warn(`ModelType ${modelType} unrecognized.`); return undefined; }
} }
@@ -48,8 +48,8 @@ function pathToFileString(path) {
return path.slice(i); return path.slice(i);
} }
function insertEmbeddingIntoText(currentText, embeddingFile, removeExtension = false) { function insertEmbeddingIntoText(currentText, embeddingFile, extensionRegex = null) {
if (removeExtension) { if (extensionRegex) {
// TODO: setting.remove_extension_embedding // TODO: setting.remove_extension_embedding
} }
// TODO: don't add if it is already in the text? // TODO: don't add if it is already in the text?
@@ -712,7 +712,7 @@ class ModelManager extends ComfyDialog {
name: "model-type", name: "model-type",
onchange: () => this.#modelGridUpdate(), onchange: () => this.#modelGridUpdate(),
}, },
[ [ // TODO: generate based on existing model folders
$el("option", ["checkpoints"]), $el("option", ["checkpoints"]),
$el("option", ["clip"]), $el("option", ["clip"]),
$el("option", ["clip_vision"]), $el("option", ["clip_vision"]),