Improved models dirs, search & ui.
This commit is contained in:
253
__init__.py
253
__init__.py
@@ -1,106 +1,229 @@
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
from aiohttp import web
|
||||
import server
|
||||
import os
|
||||
import urllib.parse
|
||||
import struct
|
||||
import json
|
||||
import requests
|
||||
import folder_paths
|
||||
|
||||
requests.packages.urllib3.disable_warnings()
|
||||
|
||||
def folder_paths_get_supported_pt_extensions(folder_name): # Missing api function.
|
||||
return folder_paths.folder_names_and_paths[folder_name][1]
|
||||
|
||||
|
||||
model_uri = os.path.join(os.getcwd(), "models")
|
||||
extension_uri = os.path.join(os.getcwd(), "custom_nodes/ComfyUI-Model-Manager")
|
||||
comfyui_model_uri = os.path.join(os.getcwd(), "models")
|
||||
extension_uri = os.path.join(os.getcwd(), "custom_nodes" + os.path.sep + "ComfyUI-Model-Manager")
|
||||
index_uri = os.path.join(extension_uri, "index.json")
|
||||
#checksum_cache_uri = os.path.join(extension_uri, "checksum_cache.txt")
|
||||
no_preview_image = os.path.join(extension_uri, "no-preview.png")
|
||||
|
||||
model_type_dir_dict = {
|
||||
"checkpoint": "checkpoints",
|
||||
"clip": "clip",
|
||||
"clip_vision": "clip_vision",
|
||||
"controlnet": "controlnet",
|
||||
"diffuser": "diffusers",
|
||||
"embedding": "embeddings",
|
||||
"gligen": "gligen",
|
||||
"hypernetwork": "hypernetworks",
|
||||
"lora": "loras",
|
||||
"style_models": "style_models",
|
||||
"unet": "unet",
|
||||
"upscale_model": "upscale_models",
|
||||
"vae": "vae",
|
||||
"vae_approx": "vae_approx",
|
||||
}
|
||||
image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp")
|
||||
#video_extensions = (".avi", ".mp4", ".webm") # TODO: Requires ffmpeg or cv2. Cache preview frame?
|
||||
|
||||
#hash_buffer_size = 4096
|
||||
|
||||
def get_safetensor_header(path):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
length_of_header = struct.unpack("<Q", f.read(8))[0]
|
||||
header_bytes = f.read(length_of_header)
|
||||
header_json = json.loads(header_bytes)
|
||||
return header_json
|
||||
except:
|
||||
return {}
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/model-manager/imgPreview")
|
||||
def end_swap_and_pop(x, i):
|
||||
x[i], x[-1] = x[-1], x[i]
|
||||
return x.pop(-1)
|
||||
|
||||
|
||||
def model_type_to_dir_name(model_type):
|
||||
# TODO: Figure out how to remove this.
|
||||
match model_type:
|
||||
case "checkpoint":
|
||||
return "checkpoints"
|
||||
case "diffuser":
|
||||
return "diffusers"
|
||||
case "embedding":
|
||||
return "embeddings"
|
||||
case "hypernetwork":
|
||||
return "hypernetworks"
|
||||
case "lora":
|
||||
return "loras"
|
||||
case "upscale_model":
|
||||
return "upscale_models"
|
||||
return model_type
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/model-manager/image-preview")
|
||||
async def img_preview(request):
|
||||
uri = request.query.get("uri")
|
||||
filepath = os.path.join(model_uri, uri)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
with open(filepath, "rb") as img_file:
|
||||
image_data = img_file.read()
|
||||
else:
|
||||
with open(os.path.join(extension_uri, "no-preview.png"), "rb") as img_file:
|
||||
image_data = img_file.read()
|
||||
image_path = no_preview_image
|
||||
image_extension = "png"
|
||||
|
||||
return web.Response(body=image_data, content_type="image/png")
|
||||
if (uri != "no-post"):
|
||||
rel_image_path = os.path.dirname(uri)
|
||||
|
||||
i = uri.find(os.path.sep)
|
||||
model_type = uri[0:i]
|
||||
|
||||
import json
|
||||
j = uri.find(os.path.sep, i + len(os.path.sep))
|
||||
if j == -1:
|
||||
j = len(rel_image_path)
|
||||
base_index = int(uri[i + len(os.path.sep):j])
|
||||
base_path = folder_paths.get_folder_paths(model_type)[base_index]
|
||||
|
||||
abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join
|
||||
if os.path.exists(abs_image_path):
|
||||
image_path = abs_image_path
|
||||
_, image_extension = os.path.splitext(uri)
|
||||
image_extension = image_extension[1:]
|
||||
|
||||
with open(image_path, "rb") as img_file:
|
||||
image_data = img_file.read()
|
||||
|
||||
return web.Response(body=image_data, content_type="image/" + image_extension)
|
||||
|
||||
#def calculate_sha256(file_path):
|
||||
# try:
|
||||
# with open(file_path, "rb") as f:
|
||||
# sha256 = hashlib.sha256()
|
||||
# while True:
|
||||
# data = f.read(hash_buffer_size)
|
||||
# if not data:
|
||||
# break
|
||||
# sha256.update(data)
|
||||
# return sha256.hexdigest()
|
||||
# except:
|
||||
# return ""
|
||||
|
||||
@server.PromptServer.instance.routes.get("/model-manager/source")
|
||||
async def load_source_from(request):
|
||||
uri = request.query.get("uri", "local")
|
||||
if uri == "local":
|
||||
with open(os.path.join(extension_uri, "index.json")) as file:
|
||||
with open(index_uri) as file:
|
||||
dataSource = json.load(file)
|
||||
else:
|
||||
response = requests.get(uri)
|
||||
dataSource = response.json()
|
||||
|
||||
# check if it installed
|
||||
model_types = os.listdir(comfyui_model_uri)
|
||||
model_types.remove("configs")
|
||||
sourceSorted = {}
|
||||
for model_type in model_types:
|
||||
sourceSorted[model_type] = []
|
||||
for item in dataSource:
|
||||
model_type = item.get("type")
|
||||
model_name = item.get("name")
|
||||
model_type_path = model_type_dir_dict.get(model_type)
|
||||
if model_type_path is None:
|
||||
continue
|
||||
if os.path.exists(os.path.join(model_uri, model_type_path, model_name)):
|
||||
item["installed"] = True
|
||||
item_model_type = model_type_to_dir_name(item.get("type"))
|
||||
sourceSorted[item_model_type].append(item)
|
||||
item["installed"] = False
|
||||
|
||||
#checksum_cache = []
|
||||
#if os.path.exists(checksum_cache_uri):
|
||||
# with open(checksum_cache_uri, "r") as file:
|
||||
# checksum_cache = file.read().splitlines()
|
||||
#else:
|
||||
# with open(checksum_cache_uri, "w") as file:
|
||||
# pass
|
||||
#print(checksum_cache)
|
||||
|
||||
for model_type in model_types:
|
||||
for model_base_path in folder_paths.get_folder_paths(model_type):
|
||||
if not os.path.exists(model_base_path): # Bug in main code?
|
||||
continue
|
||||
for cwd, _subdirs, files in os.walk(model_base_path):
|
||||
for file in files:
|
||||
source_type = sourceSorted[model_type]
|
||||
for iItem in range(len(source_type)-1,-1,-1):
|
||||
item = source_type[iItem]
|
||||
|
||||
# TODO: Make hashing optional (because it is slow to compute).
|
||||
if file != item.get("name"):
|
||||
continue
|
||||
|
||||
#file_path = os.path.join(cwd, file)
|
||||
#file_size = int(item.get("size") or 0)
|
||||
#if os.path.getsize(file_path) != file_size:
|
||||
# continue
|
||||
#
|
||||
#checksum = item.get("SHA256")
|
||||
#if checksum == "" or checksum == None:
|
||||
# continue
|
||||
# BUG: Model always hashed if same size but different hash.
|
||||
# TODO: Change code to save list (NOT dict) with absolute model path and checksum on each line
|
||||
#if checksum not in checksum_cache:
|
||||
# sha256 = calculate_sha256(file_path) # TODO: Make checksum optional!
|
||||
# checksum_cache.append(sha256)
|
||||
# print(f"{file}: calc:{sha256}, real:{checksum}")
|
||||
# if sha256 != checksum:
|
||||
# continue
|
||||
|
||||
item["installed"] = True
|
||||
end_swap_and_pop(source_type, iItem)
|
||||
|
||||
#with open(checksum_cache_uri, "w") as file:
|
||||
# file.writelines(checksum + '\n' for checksum in checksum_cache) # because python is a mess
|
||||
|
||||
return web.json_response(dataSource)
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/model-manager/models")
|
||||
async def load_download_models(request):
|
||||
model_types = os.listdir(model_uri)
|
||||
model_types = sorted(model_types)
|
||||
model_types = [content for content in model_types if content != "configs"]
|
||||
model_types = os.listdir(comfyui_model_uri)
|
||||
model_types.remove("configs")
|
||||
model_types.sort()
|
||||
|
||||
model_suffix = (".safetensors", ".pt", ".pth", ".bin", ".ckpt")
|
||||
models = {}
|
||||
|
||||
for model_type in model_types:
|
||||
model_type_uri = os.path.join(model_uri, model_type)
|
||||
filenames = os.listdir(model_type_uri)
|
||||
filenames = sorted(filenames)
|
||||
model_files = [f for f in filenames if f.endswith(model_suffix)]
|
||||
model_extensions = tuple(folder_paths_get_supported_pt_extensions(model_type))
|
||||
file_names = []
|
||||
for base_path_index, model_base_path in enumerate(folder_paths.get_folder_paths(model_type)):
|
||||
if not os.path.exists(model_base_path): # Bug in main code?
|
||||
continue
|
||||
for cwd, _subdirs, files in os.walk(model_base_path):
|
||||
dir_models = []
|
||||
dir_images = []
|
||||
|
||||
def name2item(name):
|
||||
item = {"name": name}
|
||||
file_name, ext = os.path.splitext(name)
|
||||
post_name = file_name + ".png"
|
||||
if post_name in filenames:
|
||||
post_path = os.path.join(model_type, post_name)
|
||||
item["post"] = post_path
|
||||
return item
|
||||
for file in files:
|
||||
if file.lower().endswith(model_extensions):
|
||||
dir_models.append(file)
|
||||
elif file.lower().endswith(image_extensions):
|
||||
dir_images.append(file)
|
||||
|
||||
for model in dir_models:
|
||||
model_name, _ = os.path.splitext(model)
|
||||
image = None
|
||||
for iImage in range(len(dir_images)-1, -1, -1):
|
||||
image_name, _ = os.path.splitext(dir_images[iImage])
|
||||
if model_name == image_name:
|
||||
image = end_swap_and_pop(dir_images, iImage)
|
||||
break
|
||||
rel_path = "" if cwd == model_base_path else os.path.relpath(cwd, model_base_path)
|
||||
file_names.append((model, image, base_path_index, rel_path))
|
||||
file_names.sort(key=lambda tup: tup[0].lower())
|
||||
|
||||
model_items = []
|
||||
for model, image, base_path_index, rel_path in file_names:
|
||||
name, _ = os.path.splitext(model)
|
||||
item = {
|
||||
"name": name,
|
||||
"path": os.path.join(model_type, rel_path, model).replace(os.path.sep, "/"),
|
||||
}
|
||||
if image is not None:
|
||||
raw_post = os.path.join(model_type, str(base_path_index), rel_path, image)
|
||||
item["post"] = urllib.parse.quote_plus(raw_post)
|
||||
model_items.append(item)
|
||||
|
||||
model_items = list(map(name2item, model_files))
|
||||
models[model_type] = model_items
|
||||
|
||||
return web.json_response(models)
|
||||
|
||||
|
||||
import sys
|
||||
import requests
|
||||
|
||||
|
||||
requests.packages.urllib3.disable_warnings()
|
||||
|
||||
def_headers = {
|
||||
"User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148"
|
||||
}
|
||||
@@ -154,7 +277,7 @@ def download_model_file(url, filename):
|
||||
async def download_file(request):
|
||||
body = await request.json()
|
||||
model_type = body.get("type")
|
||||
model_type_path = model_type_dir_dict.get(model_type)
|
||||
model_type_path = model_type_to_dir_name(model_type)
|
||||
if model_type_path is None:
|
||||
return web.json_response({"success": False})
|
||||
|
||||
@@ -163,9 +286,9 @@ async def download_file(request):
|
||||
return web.json_response({"success": False})
|
||||
|
||||
model_name = body.get("name")
|
||||
file_name = os.path.join(model_uri, model_type_path, model_name)
|
||||
file_name = os.path.join(comfyui_model_uri, model_type_path, model_name)
|
||||
download_model_file(download_uri, file_name)
|
||||
print("文件下载完成!")
|
||||
print("File download completed!")
|
||||
return web.json_response({"success": True})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user