View models in sub-directories of ComfyUI/models.

This commit is contained in:
Christian Bastian
2024-01-20 21:54:26 -05:00
parent a8fa7c6c15
commit 66a15b5978
3 changed files with 45 additions and 29 deletions

View File

@@ -1,5 +1,6 @@
import os
import sys
import copy
import hashlib
import importlib
@@ -30,23 +31,38 @@ image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp")
#hash_buffer_size = 4096
_folder_names_and_paths: dict[str, tuple[list[str], list[str]]] = None
def folder_paths_folder_names_and_paths(refresh = False) -> dict[str, tuple[list[str], list[str]]]:
global _folder_names_and_paths
if refresh or _folder_names_and_paths is None:
_folder_names_and_paths = {}
for item_name in os.listdir(comfyui_model_uri):
item_path = os.path.join(comfyui_model_uri, item_name)
if not os.path.isdir(item_path):
continue
if item_name in folder_paths.folder_names_and_paths:
dir_paths, extensions = copy.deepcopy(folder_paths.folder_names_and_paths[item_name])
else:
dir_paths = [item_name]
extensions = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"] # TODO: magic values
_folder_names_and_paths[item_name] = (dir_paths, extensions)
return _folder_names_and_paths
def folder_paths_get_folder_paths(folder_name): # API function crashes querying unknown model folder
paths = folder_paths.folder_names_and_paths
def folder_paths_get_folder_paths(folder_name, refresh = False) -> list[str]: # API function crashes querying unknown model folder
paths = folder_paths_folder_names_and_paths(refresh)
if folder_name in paths:
return paths[folder_name][0][:]
return paths[folder_name][0]
maybe_path = os.path.join(comfyui_model_uri, folder_name)
if os.path.exists(maybe_path):
return [maybe_path]
return []
def folder_paths_get_supported_pt_extensions(folder_name): # Missing API function
paths = folder_paths.folder_names_and_paths
def folder_paths_get_supported_pt_extensions(folder_name, refresh = False) -> list[str]: # Missing API function
paths = folder_paths_folder_names_and_paths(refresh)
if folder_name in paths:
return paths[folder_name][1]
return set([".ckpt", ".pt", ".bin", ".pth", ".safetensors"])
return set([".ckpt", ".pt", ".bin", ".pth", ".safetensors"]) # TODO: magic values
def get_safetensor_header(path):