Preview image improvements.
- Model Tab grid receives smaller previews from server. - Attempted to make PIL image `info` serializable for previews. - Get full size previews from Civitai. - Note, the Civitai server may return nothing for the image id. (External bug?) - Support downloading previews from https://civitai.com/images/ - Lazy Loading in Model Tab.
This commit is contained in:
180
__init__.py
180
__init__.py
@@ -198,6 +198,26 @@ def server_rules():
|
||||
server_settings = config_loader.yaml_load(server_settings_uri, server_rules())
|
||||
config_loader.yaml_save(server_settings_uri, server_rules(), server_settings)
|
||||
|
||||
|
||||
def get_def_headers(url=""):
|
||||
def_headers = {
|
||||
"User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148",
|
||||
}
|
||||
|
||||
if url.startswith("https://civitai.com/"):
|
||||
api_key = server_settings["civitai_api_key"]
|
||||
if (api_key != ""):
|
||||
def_headers["Authorization"] = f"Bearer {api_key}"
|
||||
url += "&" if "?" in url else "?" # not the most robust solution
|
||||
url += f"token={api_key}" # TODO: Authorization didn't work in the header
|
||||
elif url.startswith("https://huggingface.co/"):
|
||||
api_key = server_settings["huggingface_api_key"]
|
||||
if api_key != "":
|
||||
def_headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return def_headers
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/model-manager/settings/load")
|
||||
async def load_ui_settings(request):
|
||||
rules = ui_rules()
|
||||
@@ -218,34 +238,105 @@ async def save_ui_settings(request):
|
||||
})
|
||||
|
||||
|
||||
from PIL import Image, TiffImagePlugin
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
def PIL_cast_serializable(v):
|
||||
# source: https://github.com/python-pillow/Pillow/issues/6199#issuecomment-1214854558
|
||||
if isinstance(v, TiffImagePlugin.IFDRational):
|
||||
return float(v)
|
||||
elif isinstance(v, tuple):
|
||||
return tuple(PIL_cast_serializable(t) for t in v)
|
||||
elif isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
elif isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
v[kk] = PIL_cast_serializable(vv)
|
||||
return v
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
def get_safetensors_image_bytes(path):
|
||||
if not os.path.isfile(path):
|
||||
raise RuntimeError("Path was invalid!")
|
||||
header = get_safetensor_header(path)
|
||||
metadata = header.get("__metadata__", None)
|
||||
if metadata is None:
|
||||
return None
|
||||
thumbnail = metadata.get("modelspec.thumbnail", None)
|
||||
if thumbnail is None:
|
||||
return None
|
||||
image_data = thumbnail.split(',')[1]
|
||||
return base64.b64decode(image_data)
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.get("/model-manager/preview/get")
|
||||
async def get_model_preview(request):
|
||||
uri = request.query.get("uri")
|
||||
|
||||
image_path = no_preview_image
|
||||
image_type = "png"
|
||||
image_data = None
|
||||
if uri != "no-preview":
|
||||
sep = os.path.sep
|
||||
uri = uri.replace("/" if sep == "\\" else "/", sep)
|
||||
path, _ = search_path_to_system_path(uri)
|
||||
head, extension = split_valid_ext(path, preview_extensions)
|
||||
if os.path.exists(path):
|
||||
image_type = extension.rsplit(".", 1)[1]
|
||||
image_path = path
|
||||
elif os.path.exists(head) and head.endswith(".safetensors"):
|
||||
image_type = extension.rsplit(".", 1)[1]
|
||||
header = get_safetensor_header(head)
|
||||
metadata = header.get("__metadata__", None)
|
||||
if metadata is not None:
|
||||
thumbnail = metadata.get("modelspec.thumbnail", None)
|
||||
if thumbnail is not None:
|
||||
image_data = thumbnail.split(',')[1]
|
||||
image_data = base64.b64decode(image_data)
|
||||
elif os.path.exists(head) and head.endswith(".safetensors"):
|
||||
image_path = head
|
||||
image_type = extension.rsplit(".", 1)[1]
|
||||
|
||||
if image_data == None:
|
||||
with open(image_path, "rb") as file:
|
||||
image_data = file.read()
|
||||
w = request.query.get("width")
|
||||
h = request.query.get("height")
|
||||
try:
|
||||
w = int(w)
|
||||
if w < 1:
|
||||
w = None
|
||||
except:
|
||||
w = None
|
||||
try:
|
||||
h = int(h)
|
||||
if w < 1:
|
||||
h = None
|
||||
except:
|
||||
h = None
|
||||
|
||||
image_data = None
|
||||
if w is None and h is None: # full size
|
||||
if image_path.endswith(".safetensors"):
|
||||
image_data = get_safetensors_image_bytes(image_path)
|
||||
else:
|
||||
with open(image_path, "rb") as image:
|
||||
image_data = image.read()
|
||||
else:
|
||||
if image_path.endswith(".safetensors"):
|
||||
image_data = get_safetensors_image_bytes(image_path)
|
||||
fp = io.BytesIO(image_data)
|
||||
else:
|
||||
fp = image_path
|
||||
|
||||
with Image.open(fp) as image:
|
||||
w0, h0 = image.size
|
||||
if w is None:
|
||||
w = (h * w0) // h0
|
||||
elif h is None:
|
||||
h = (w * h0) // w0
|
||||
|
||||
exif = image.getexif()
|
||||
|
||||
metadata = None
|
||||
if len(image.info) > 0:
|
||||
metadata = PngInfo()
|
||||
for (key, value) in image.info.items():
|
||||
value_str = str(PIL_cast_serializable(value)) # not sure if this is correct (sometimes includes exif)
|
||||
metadata.add_text(key, value_str)
|
||||
|
||||
image.thumbnail((w, h))
|
||||
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, format=image.format, exif=exif, pnginfo=metadata)
|
||||
image_data = image_bytes.getvalue()
|
||||
|
||||
return web.Response(body=image_data, content_type="image/" + image_type)
|
||||
|
||||
@@ -268,7 +359,30 @@ def download_model_preview(formdata):
|
||||
|
||||
image = formdata.get("image", None)
|
||||
if type(image) is str:
|
||||
_, image_extension = split_valid_ext(image, image_extensions) # TODO: doesn't work for https://civitai.com/images/...
|
||||
civitai_image_url = "https://civitai.com/images/"
|
||||
if image.startswith(civitai_image_url):
|
||||
image_id = re.search(r"^\d+", image[len(civitai_image_url):]).group(0)
|
||||
image_id = str(int(image_id))
|
||||
image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}"
|
||||
def_headers = get_def_headers(image_info_url)
|
||||
response = requests.get(
|
||||
url=image_info_url,
|
||||
stream=False,
|
||||
verify=False,
|
||||
headers=def_headers,
|
||||
proxies=None,
|
||||
allow_redirects=False,
|
||||
)
|
||||
if response.ok:
|
||||
content_type = response.headers.get("Content-Type")
|
||||
info = response.json()
|
||||
items = info["items"]
|
||||
if len(items) == 0:
|
||||
raise RuntimeError("Civitai /api/v1/images returned 0 items!")
|
||||
image = items[0]["url"]
|
||||
else:
|
||||
raise RuntimeError("Bad response from api/v1/images!")
|
||||
_, image_extension = split_valid_ext(image, image_extensions)
|
||||
if image_extension == "":
|
||||
raise ValueError("Invalid image type!")
|
||||
image_path = path_without_extension + image_extension
|
||||
@@ -474,21 +588,15 @@ def download_file(url, filename, overwrite):
|
||||
|
||||
filename_temp = filename + ".download"
|
||||
|
||||
def_headers = {
|
||||
"User-Agent": "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148",
|
||||
}
|
||||
|
||||
if url.startswith("https://civitai.com/"):
|
||||
api_key = server_settings["civitai_api_key"]
|
||||
if (api_key != ""):
|
||||
def_headers["Authorization"] = f"Bearer {api_key}"
|
||||
url += "&" if "?" in url else "?" # not the most robust solution
|
||||
url += f"token={api_key}" # TODO: Authorization didn't work in the header
|
||||
elif url.startswith("https://huggingface.co/"):
|
||||
api_key = server_settings["huggingface_api_key"]
|
||||
if api_key != "":
|
||||
def_headers["Authorization"] = f"Bearer {api_key}"
|
||||
rh = requests.get(url=url, stream=True, verify=False, headers=def_headers, proxies=None, allow_redirects=False)
|
||||
def_headers = get_def_headers(url)
|
||||
rh = requests.get(
|
||||
url=url,
|
||||
stream=True,
|
||||
verify=False,
|
||||
headers=def_headers,
|
||||
proxies=None,
|
||||
allow_redirects=False,
|
||||
)
|
||||
if not rh.ok:
|
||||
raise ValueError(
|
||||
"Unable to download! Request header status code: " +
|
||||
@@ -501,8 +609,16 @@ def download_file(url, filename, overwrite):
|
||||
|
||||
headers = {"Range": "bytes=%d-" % downloaded_size}
|
||||
headers["User-Agent"] = def_headers["User-Agent"]
|
||||
|
||||
r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None, allow_redirects=False)
|
||||
headers["Authorization"] = def_headers.get("Authorization", None)
|
||||
|
||||
r = requests.get(
|
||||
url=url,
|
||||
stream=True,
|
||||
verify=False,
|
||||
headers=headers,
|
||||
proxies=None,
|
||||
allow_redirects=False,
|
||||
)
|
||||
if rh.status_code == 307 and r.status_code == 307:
|
||||
# Civitai redirect
|
||||
redirect_url = r.content.decode("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user