Triage folder_paths crashing on unregistered dir.

This commit is contained in:
Christian Bastian
2024-01-04 07:13:28 -05:00
parent 5517ae68e4
commit 724a9425c4
2 changed files with 55 additions and 39 deletions

View File

@@ -11,10 +11,6 @@ import folder_paths
requests.packages.urllib3.disable_warnings()
def folder_paths_get_supported_pt_extensions(folder_name): # Missing api function.
return folder_paths.folder_names_and_paths[folder_name][1]
comfyui_model_uri = os.path.join(os.getcwd(), "models")
extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager")
index_uri = os.path.join(extension_uri, "index.json")
@@ -26,6 +22,26 @@ image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp")
#hash_buffer_size = 4096
def folder_paths_get_folder_paths(folder_name): # API function crashes querying unknown model folder
paths = folder_paths.folder_names_and_paths
if folder_name in paths:
return paths[folder_name][0][:]
maybe_path = os.path.join(comfyui_model_uri, folder_name)
if os.path.exists(maybe_path):
return [maybe_path]
return []
def folder_paths_get_supported_pt_extensions(folder_name): # Missing API function
paths = folder_paths.folder_names_and_paths
if folder_name in paths:
return paths[folder_name][1]
return set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
def get_safetensor_header(path):
try:
with open(path, "rb") as f:
@@ -43,21 +59,21 @@ def end_swap_and_pop(x, i):
def model_type_to_dir_name(model_type):
# TODO: Figure out how to remove this.
match model_type:
case "checkpoint":
return "checkpoints"
case "diffuser":
return "diffusers"
case "embedding":
return "embeddings"
case "hypernetwork":
return "hypernetworks"
case "lora":
return "loras"
case "upscale_model":
return "upscale_models"
return model_type
if model_type == "checkpoint": return "checkpoints"
#elif model_type == "clip": return "clip"
#elif model_type == "clip_vision": return "clip_vision"
#elif model_type == "controlnet": return "controlnet"
elif model_type == "diffuser": return "diffusers"
elif model_type == "embedding": return "embeddings"
#elif model_type== "gligen": return "gligen"
elif model_type == "hypernetwork": return "hypernetworks"
elif model_type == "lora": return "loras"
#elif model_type == "style_models": return "style_models"
#elif model_type == "unet": return "unet"
elif model_type == "upscale_model": return "upscale_models"
#elif model_type == "vae": return "vae"
#elif model_type == "vae_approx": return "vae_approx"
else: return model_type
@server.PromptServer.instance.routes.get("/model-manager/image-preview")
@@ -77,7 +93,7 @@ async def img_preview(request):
if j == -1:
j = len(rel_image_path)
base_index = int(uri[i + len(os.path.sep):j])
base_path = folder_paths.get_folder_paths(model_type)[base_index]
base_path = folder_paths_get_folder_paths(model_type)[base_index]
abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join
if os.path.exists(abs_image_path):
@@ -133,7 +149,7 @@ async def load_source_from(request):
#print(checksum_cache)
for model_type in model_types:
for model_base_path in folder_paths.get_folder_paths(model_type):
for model_base_path in folder_paths_get_folder_paths(model_type):
if not os.path.exists(model_base_path): # Bug in main code?
continue
for cwd, _subdirs, files in os.walk(model_base_path):
@@ -182,7 +198,7 @@ async def load_download_models(request):
for model_type in model_types:
model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type))
file_names = []
for base_path_index, model_base_path in enumerate(folder_paths.get_folder_paths(model_type)):
for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)):
if not os.path.exists(model_base_path): # Bug in main code?
continue
for cwd, _subdirs, files in os.walk(model_base_path):

View File

@@ -22,20 +22,20 @@ function request(url, options) {
}
function modelNodeType(modelType) {
if (modelType === "checkpoints") return "CheckpointLoaderSimple";
else if (modelType === "clip") return "CLIPLoader";
else if (modelType === "clip_vision") return "CLIPVisionLoader";
else if (modelType === "controlnet") return "ControlNetLoader";
else if (modelType === "diffusers") return "DiffusersLoader";
else if (modelType === "embeddings") return "Embedding";
else if (modelType === "gligen") return "GLIGENLoader";
else if (modelType === "hypernetworks") return "HypernetworkLoader";
else if (modelType === "loras") return "LoraLoader";
else if (modelType === "style_models") return "StyleModelLoader";
else if (modelType === "unet") return "UNETLoader";
else if (modelType === "upscale_models") return "UpscaleModelLoader";
else if (modelType === "vae") return "VAELoader";
else if (modelType === "vae_approx") return undefined;
if (modelType === "checkpoints") { return "CheckpointLoaderSimple"; }
else if (modelType === "clip") { return "CLIPLoader"; }
else if (modelType === "clip_vision") { return "CLIPVisionLoader"; }
else if (modelType === "controlnet") { return "ControlNetLoader"; }
else if (modelType === "diffusers") { return "DiffusersLoader"; }
else if (modelType === "embeddings") { return "Embedding"; }
else if (modelType === "gligen") { return "GLIGENLoader"; }
else if (modelType === "hypernetworks") { return "HypernetworkLoader"; }
else if (modelType === "loras") { return "LoraLoader"; }
else if (modelType === "style_models") { return "StyleModelLoader"; }
else if (modelType === "unet") { return "UNETLoader"; }
else if (modelType === "upscale_models") { return "UpscaleModelLoader"; }
else if (modelType === "vae") { return "VAELoader"; }
else if (modelType === "vae_approx") { return undefined; }
else { console.warn(`ModelType ${modelType} unrecognized.`); return undefined; }
}
@@ -48,8 +48,8 @@ function pathToFileString(path) {
return path.slice(i);
}
function insertEmbeddingIntoText(currentText, embeddingFile, removeExtension = false) {
if (removeExtension) {
function insertEmbeddingIntoText(currentText, embeddingFile, extensionRegex = null) {
if (extensionRegex) {
// TODO: setting.remove_extension_embedding
}
// TODO: don't add if it is already in the text?
@@ -712,7 +712,7 @@ class ModelManager extends ComfyDialog {
name: "model-type",
onchange: () => this.#modelGridUpdate(),
},
[
[ // TODO: generate based on existing model folders
$el("option", ["checkpoints"]),
$el("option", ["clip"]),
$el("option", ["clip_vision"]),