Triage folder_paths crashing on unregistered dir.

This commit is contained in:
Christian Bastian
2024-01-04 07:13:28 -05:00
parent 5517ae68e4
commit 724a9425c4
2 changed files with 55 additions and 39 deletions

View File

@@ -11,10 +11,6 @@ import folder_paths
requests.packages.urllib3.disable_warnings()
def folder_paths_get_supported_pt_extensions(folder_name): # Missing api function.
return folder_paths.folder_names_and_paths[folder_name][1]
comfyui_model_uri = os.path.join(os.getcwd(), "models")
extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager")
index_uri = os.path.join(extension_uri, "index.json")
@@ -26,6 +22,26 @@ image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp")
#hash_buffer_size = 4096
def folder_paths_get_folder_paths(folder_name): # API function crashes querying unknown model folder
paths = folder_paths.folder_names_and_paths
if folder_name in paths:
return paths[folder_name][0][:]
maybe_path = os.path.join(comfyui_model_uri, folder_name)
if os.path.exists(maybe_path):
return [maybe_path]
return []
def folder_paths_get_supported_pt_extensions(folder_name): # Missing API function
paths = folder_paths.folder_names_and_paths
if folder_name in paths:
return paths[folder_name][1]
return set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
def get_safetensor_header(path):
try:
with open(path, "rb") as f:
@@ -43,21 +59,21 @@ def end_swap_and_pop(x, i):
def model_type_to_dir_name(model_type):
# TODO: Figure out how to remove this.
match model_type:
case "checkpoint":
return "checkpoints"
case "diffuser":
return "diffusers"
case "embedding":
return "embeddings"
case "hypernetwork":
return "hypernetworks"
case "lora":
return "loras"
case "upscale_model":
return "upscale_models"
return model_type
if model_type == "checkpoint": return "checkpoints"
#elif model_type == "clip": return "clip"
#elif model_type == "clip_vision": return "clip_vision"
#elif model_type == "controlnet": return "controlnet"
elif model_type == "diffuser": return "diffusers"
elif model_type == "embedding": return "embeddings"
#elif model_type== "gligen": return "gligen"
elif model_type == "hypernetwork": return "hypernetworks"
elif model_type == "lora": return "loras"
#elif model_type == "style_models": return "style_models"
#elif model_type == "unet": return "unet"
elif model_type == "upscale_model": return "upscale_models"
#elif model_type == "vae": return "vae"
#elif model_type == "vae_approx": return "vae_approx"
else: return model_type
@server.PromptServer.instance.routes.get("/model-manager/image-preview")
@@ -77,7 +93,7 @@ async def img_preview(request):
if j == -1:
j = len(rel_image_path)
base_index = int(uri[i + len(os.path.sep):j])
base_path = folder_paths.get_folder_paths(model_type)[base_index]
base_path = folder_paths_get_folder_paths(model_type)[base_index]
abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join
if os.path.exists(abs_image_path):
@@ -133,7 +149,7 @@ async def load_source_from(request):
#print(checksum_cache)
for model_type in model_types:
for model_base_path in folder_paths.get_folder_paths(model_type):
for model_base_path in folder_paths_get_folder_paths(model_type):
if not os.path.exists(model_base_path): # Bug in main code?
continue
for cwd, _subdirs, files in os.walk(model_base_path):
@@ -182,7 +198,7 @@ async def load_download_models(request):
for model_type in model_types:
model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type))
file_names = []
for base_path_index, model_base_path in enumerate(folder_paths.get_folder_paths(model_type)):
for base_path_index, model_base_path in enumerate(folder_paths_get_folder_paths(model_type)):
if not os.path.exists(model_base_path): # Bug in main code?
continue
for cwd, _subdirs, files in os.walk(model_base_path):