Model Previews can now be set in Model View.

- Abstracted out Radio Buttons for Preview selection (mostly clean).
- Added REST API for preview/set and preview/delete.
- Added dateModified to query string so the browser can detect out of date preview images.
- Added image path and dateModified to Model Info payload.
This commit is contained in:
Christian Bastian
2024-02-22 05:11:38 -05:00
parent c4b6ddd5ca
commit d951a508ed
4 changed files with 665 additions and 337 deletions

View File

@@ -1,4 +1,5 @@
import os
import io
import pathlib
import shutil
from datetime import datetime
@@ -30,7 +31,7 @@ ui_settings_uri = os.path.join(extension_uri, "ui_settings.yaml")
server_settings_uri = os.path.join(extension_uri, "server_settings.yaml")
fallback_model_extensions = set([".bin", ".ckpt", ".onnx", ".pt", ".pth", ".safetensors"]) # TODO: magic values
image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp")
image_extensions = (".apng", ".gif", ".jpeg", ".jpg", ".png", ".webp") # TODO: JavaScript does not know about this (x2 states)
#video_extensions = (".avi", ".mp4", ".webm") # TODO: Requires ffmpeg or cv2. Cache preview frame?
_folder_names_and_paths = None # dict[str, tuple[list[str], list[str]]]
@@ -195,21 +196,11 @@ async def get_model_preview(request):
image_path = no_preview_image
image_extension = "png"
if uri != "no-post":
rel_image_path = os.path.dirname(uri)
i = uri.find(os.path.sep)
model_type = uri[0:i]
j = uri.find(os.path.sep, i + len(os.path.sep))
if j == -1:
j = len(rel_image_path)
base_index = int(uri[i + len(os.path.sep):j])
base_path = folder_paths_get_folder_paths(model_type)[base_index]
abs_image_path = os.path.normpath(base_path + os.path.sep + uri[j:]) # do NOT use os.path.join
if os.path.exists(abs_image_path):
image_path = abs_image_path
if uri != "no-preview":
sep = os.path.sep
uri = uri.replace("/" if sep == "\\" else "/", os.path.sep)
image_path, _ = search_path_to_system_path(uri)
if os.path.exists(image_path):
_, image_extension = os.path.splitext(uri)
image_extension = image_extension[1:]
@@ -219,6 +210,64 @@ async def get_model_preview(request):
return web.Response(body=image_data, content_type="image/" + image_extension)
def download_model_preview(formdata):
path = formdata.get("path", None)
if type(path) is not str:
raise ("Invalid path!")
path, _ = search_path_to_system_path(path)
path_without_extension, _ = os.path.splitext(path)
overwrite = formdata.get("overwrite", "true").lower()
overwrite = True if overwrite == "true" else False
image = formdata.get("image", None)
if type(image) is str:
image_path = download_image(image, path, overwrite)
_, image_extension = os.path.splitext(image_path)
else:
content_type = image.content_type
if not content_type.startswith("image/"):
raise ("Invalid content type!")
image_extension = "." + content_type[len("image/"):]
if image_extension not in image_extensions:
raise ("Invalid extension!")
image_path = path_without_extension + image_extension
if not overwrite and os.path.isfile(image_path):
raise ("Image already exists!")
file: io.IOBase = image.file
image_data = file.read()
with open(image_path, "wb") as f:
f.write(image_data)
delete_same_name_files(path_without_extension, image_extensions, image_extension)
@server.PromptServer.instance.routes.post("/model-manager/preview/set")
async def set_model_preview(request):
formdata = await request.post()
try:
download_model_preview(formdata)
return web.json_response({ "success": True })
except ValueError as e:
print(e, file=sys.stderr, flush=True)
return web.json_response({ "success": False })
@server.PromptServer.instance.routes.post("/model-manager/preview/delete")
async def delete_model_preview(request):
model_path = request.query.get("path", None)
if model_path is None:
return web.json_response({ "success": False })
model_path = urllib.parse.unquote(model_path)
file, _ = search_path_to_system_path(model_path)
path_and_name, _ = os.path.splitext(file)
delete_same_name_files(path_and_name, image_extensions)
return web.json_response({ "success": True })
@server.PromptServer.instance.routes.get("/model-manager/models/list")
async def load_download_models(request):
model_types = os.listdir(comfyui_model_uri)
@@ -245,34 +294,40 @@ async def load_download_models(request):
for model in dir_models:
model_name, _ = os.path.splitext(model)
image = None
image_modified = None
for iImage in range(len(dir_images)-1, -1, -1):
image_name, _ = os.path.splitext(dir_images[iImage])
if model_name == image_name:
image = end_swap_and_pop(dir_images, iImage)
img_abs_path = os.path.join(cwd, image)
image_modified = pathlib.Path(img_abs_path).stat().st_mtime_ns
break
abs_path = os.path.join(cwd, model)
stats = pathlib.Path(abs_path).stat()
date_modified = stats.st_mtime_ns
date_created = stats.st_ctime_ns
model_modified = stats.st_mtime_ns
model_created = stats.st_ctime_ns
rel_path = "" if cwd == model_base_path else os.path.relpath(cwd, model_base_path)
info = (model, image, base_path_index, rel_path, date_modified, date_created)
info = (model, image, base_path_index, rel_path, model_modified, model_created, image_modified)
file_infos.append(info)
file_infos.sort(key=lambda tup: tup[4], reverse=True) # TODO: remove sort; sorted on client
model_items = []
for model, image, base_path_index, rel_path, date_modified, date_created in file_infos:
for model, image, base_path_index, rel_path, model_modified, model_created, image_modified in file_infos:
item = {
"name": model,
"path": "/" + os.path.join(model_type, str(base_path_index), rel_path, model).replace(os.path.sep, "/"), # relative logical path
#"systemPath": os.path.join(rel_path, model), # relative system path (less information than "search path")
"dateModified": date_modified,
"dateCreated": date_created,
"dateModified": model_modified,
"dateCreated": model_created,
#"dateLastUsed": "", # TODO: track server-side, send increment client-side
#"countUsed": 0, # TODO: track server-side, send increment client-side
}
if image is not None:
raw_post = os.path.join(model_type, str(base_path_index), rel_path, image)
item["post"] = urllib.parse.quote_plus(raw_post)
item["preview"] = {
"path": urllib.parse.quote_plus(raw_post),
"dateModified": urllib.parse.quote_plus(str(image_modified)),
}
model_items.append(item)
models[model_type] = model_items
@@ -342,13 +397,14 @@ async def directory_list(request):
def download_file(url, filename, overwrite):
if not overwrite and os.path.isfile(filename):
raise Exception("File already exists!")
raise ValueError("File already exists!")
filename_temp = filename + ".download"
def_headers = {
"User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148",
}
if url.startswith("https://civitai.com/"):
api_key = server_settings["civitai_api_key"]
if (api_key != ""):
@@ -358,10 +414,9 @@ def download_file(url, filename, overwrite):
api_key = server_settings["huggingface_api_key"]
if api_key != "":
def_headers["Authorization"] = f"Bearer {api_key}"
rh = requests.get(url=url, stream=True, verify=False, headers=def_headers, proxies=None, allow_redirects=False)
if not rh.ok:
raise Exception("Unable to download")
raise ValueError("Unable to download")
downloaded_size = 0
if rh.status_code == 200 and os.path.exists(filename_temp):
@@ -369,7 +424,7 @@ def download_file(url, filename, overwrite):
headers = {"Range": "bytes=%d-" % downloaded_size}
headers["User-Agent"] = def_headers["User-Agent"]
r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None, allow_redirects=False)
if rh.status_code == 307 and r.status_code == 307:
# Civitai redirect
@@ -377,7 +432,7 @@ def download_file(url, filename, overwrite):
if not redirect_url.startswith("http"):
# Civitai requires login (NSFW or user-required)
# TODO: inform user WHY download failed
raise Exception("Unable to download!")
raise ValueError("Unable to download!")
download_file(redirect_url, filename, overwrite)
return
if rh.status_code == 302 and r.status_code == 302:
@@ -385,7 +440,7 @@ def download_file(url, filename, overwrite):
redirect_url = r.content.decode("utf-8")
redirect_url_index = redirect_url.find("http")
if redirect_url_index == -1:
raise Exception("Unable to download!")
raise ValueError("Unable to download!")
download_file(redirect_url[redirect_url_index:], filename, overwrite)
return
elif rh.status_code == 200 and r.status_code == 206:
@@ -419,18 +474,33 @@ def download_file(url, filename, overwrite):
)
)
sys.stdout.flush()
print()
if overwrite and os.path.isfile(filename):
os.remove(filename)
os.rename(filename_temp, filename)
def download_image(image_uri, model_path, overwrite):
extension = None # TODO: doesn't work for https://civitai.com/images/...
for image_extension in image_extensions:
if image_uri.endswith(image_extension):
extension = image_extension
break
if extension is None:
raise ValueError("Invalid image type!")
path_without_extension, _ = os.path.splitext(model_path)
file = path_without_extension + extension
download_file(image_uri, file, overwrite)
return file
@server.PromptServer.instance.routes.get("/model-manager/model/info")
async def get_model_info(request):
model_path = request.query.get("path", None)
if model_path is None:
return web.json_response({})
return web.json_response({ "success": False })
model_path = urllib.parse.unquote(model_path)
file, _ = search_path_to_system_path(model_path)
@@ -441,12 +511,25 @@ async def get_model_info(request):
path, name = os.path.split(model_path)
info["File Name"] = name
info["File Directory"] = path
info["File Size"] = os.path.getsize(file)
info["File Size"] = str(os.path.getsize(file)) + " bytes"
stats = pathlib.Path(file).stat()
date_format = "%Y/%m/%d %H:%M:%S"
info["Date Created"] = datetime.fromtimestamp(stats.st_ctime).strftime(date_format)
info["Date Modified"] = datetime.fromtimestamp(stats.st_mtime).strftime(date_format)
file_name, _ = os.path.splitext(file)
for extension in image_extensions:
maybe_image = file_name + extension
if os.path.isfile(maybe_image):
image_path, _ = os.path.splitext(model_path)
image_modified = pathlib.Path(maybe_image).stat().st_mtime_ns
info["Preview"] = {
"path": urllib.parse.quote_plus(image_path + extension),
"dateModified": urllib.parse.quote_plus(str(image_modified)),
}
break
header = get_safetensor_header(file)
metadata = header.get("__metadata__", None)
if metadata is not None:
@@ -455,7 +538,6 @@ async def get_model_info(request):
info["Hash"] = metadata.get("sshs_model_hash", "")
info["Output Name"] = metadata.get("ss_output_name", "")
file_name, _ = os.path.splitext(file)
txt_file = file_name + ".txt"
notes = ""
if os.path.isfile(txt_file):
@@ -500,26 +582,27 @@ async def get_system_separator(request):
@server.PromptServer.instance.routes.post("/model-manager/model/download")
async def download_model(request):
body = await request.json()
formdata = await request.post()
result = {
"success": False,
"invalid": None,
}
overwrite = body.get("overwrite", False)
overwrite = formdata.get("overwrite", "false").lower()
overwrite = True if overwrite == "true" else False
model_path = body.get("path", "/0")
model_path = formdata.get("path", "/0")
directory, model_type = search_path_to_system_path(model_path)
if directory is None:
result["invalid"] = "path"
return web.json_response(result)
download_uri = body.get("download")
download_uri = formdata.get("download")
if download_uri is None:
result["invalid"] = "download"
return web.json_response(result)
name = body.get("name")
name = formdata.get("name")
model_extension = None
for ext in folder_paths_get_supported_pt_extensions(model_type):
if name.endswith(ext):
@@ -531,27 +614,22 @@ async def download_model(request):
file_name = os.path.join(directory, name)
try:
download_file(download_uri, file_name, overwrite)
except:
result["invalid"] = "download"
except Exception as e:
print(e, file=sys.stderr, flush=True)
result["invalid"] = "model"
return web.json_response(result)
image_uri = body.get("image")
if image_uri is not None and image_uri != "":
image_extension = None # TODO: doesn't work for https://civitai.com/images/...
for ext in image_extensions:
if image_uri.endswith(ext):
image_extension = ext
break
if image_extension is not None:
file_path_without_extension = name[:len(name) - len(model_extension)]
image_name = os.path.join(
directory,
file_path_without_extension + image_extension
)
try:
download_file(image_uri, image_name, overwrite)
except Exception as e:
print(e, file=sys.stderr, flush=True)
image = formdata.get("image")
if image is not None and image != "":
try:
download_model_preview({
"path": model_path + os.sep + name,
"image": image,
"overwrite": formdata.get("overwrite"),
})
except Exception as e:
print(e, file=sys.stderr, flush=True)
result["invalid"] = "preview"
result["success"] = True
return web.json_response(result)
@@ -579,7 +657,8 @@ async def move_model(request):
new_file = os.path.join(new_path, filename)
try:
shutil.move(old_file, new_file)
except:
except ValueError as e:
print(e, file=sys.stderr, flush=True)
return web.json_response({ "success": False })
old_file_without_extension, _ = os.path.splitext(old_file)
@@ -590,12 +669,20 @@ async def move_model(request):
if os.path.isfile(old_file):
try:
shutil.move(old_file, new_file_without_extension + extension)
except Exception as e:
except ValueError as e:
print(e, file=sys.stderr, flush=True)
return web.json_response({ "success": True })
def delete_same_name_files(path_without_extension, extensions, keep_extension=None):
for extension in extensions:
if extension == keep_extension: continue
image_file = path_without_extension + extension
if os.path.isfile(image_file):
os.remove(image_file)
@server.PromptServer.instance.routes.post("/model-manager/model/delete")
async def delete_model(request):
result = { "success": False }
@@ -623,10 +710,7 @@ async def delete_model(request):
path_and_name, _ = os.path.splitext(file)
for img_ext in image_extensions:
image_file = path_and_name + img_ext
if os.path.isfile(image_file):
os.remove(image_file)
delete_same_name_files(path_and_name, image_extensions)
txt_file = path_and_name + ".txt"
if os.path.isfile(txt_file):
@@ -656,7 +740,8 @@ async def set_notes(request):
try:
with open(filename, "w", encoding="utf-8") as f:
f.write(text)
except:
except ValueError as e:
print(e, file=sys.stderr, flush=True)
web.json_response({ "success": False })
return web.json_response({ "success": True })