diff --git a/__init__.py b/__init__.py index 53b1dbd..0a26952 100644 --- a/__init__.py +++ b/__init__.py @@ -272,11 +272,37 @@ def get_safetensors_image_bytes(path): return base64.b64decode(image_data) +def get_image_info(image): + metadata = None + if len(image.info) > 0: + metadata = PngInfo() + for (key, value) in image.info.items(): + value_str = str(PIL_cast_serializable(value)) # not sure if this is correct (sometimes includes exif) + metadata.add_text(key, value_str) + return metadata + + +def image_format_is_equal(f1, f2): + if not isinstance(f1, str) or not isinstance(f2, str): + return False + f1 = f1.upper() + f2 = f2.upper() + return f1 == f2 or (f1 == "JPG" and f2 == "JPEG") or (f1 == "JPEG" and f2 == "JPG") + + @server.PromptServer.instance.routes.get("/model-manager/preview/get") async def get_model_preview(request): uri = request.query.get("uri") + quality = request.query.get("quality", "75") + try: + quality = int(quality) + except: + quality = 75 + response_image_format = request.query.get("image-format", None) + if isinstance(response_image_format, str): + response_image_format = response_image_format.upper() + image_path = no_preview_image - image_type = "png" file_name = os.path.split(no_preview_image)[1] if uri != "no-preview": sep = os.path.sep @@ -285,12 +311,10 @@ async def get_model_preview(request): head, extension = split_valid_ext(path, preview_extensions) if os.path.exists(path): image_path = path - image_type = extension.rsplit(".", 1)[1] - file_name = os.path.split(head)[1] + "." + image_type + file_name = os.path.split(head)[1] + extension elif os.path.exists(head) and head.endswith(".safetensors"): image_path = head - image_type = extension.rsplit(".", 1)[1] - file_name = os.path.splitext(os.path.split(head)[1])[0] + "." + image_type + file_name = os.path.splitext(os.path.split(head)[1])[0] + extension w = request.query.get("width") h = request.query.get("height") @@ -314,6 +338,19 @@ async def get_model_preview(request): else: with open(image_path, "rb") as image: image_data = image.read() + fp = io.BytesIO(image_data) + with Image.open(fp) as image: + image_format = image.format + if response_image_format is None: + response_image_format = image_format + elif not image_format_is_equal(response_image_format, image_format): + exif = image.getexif() + metadata = get_image_info(image) + if response_image_format == 'JPEG' or response_image_format == 'JPG': + image = image.convert('RGB') + image_bytes = io.BytesIO() + image.save(image_bytes, format=response_image_format, exif=exif, pnginfo=metadata, quality=quality) + image_data = image_bytes.getvalue() else: if image_path.endswith(".safetensors"): image_data = get_safetensors_image_bytes(image_path) @@ -322,7 +359,9 @@ async def get_model_preview(request): fp = image_path with Image.open(fp) as image: - format = image.format + image_format = image.format + if response_image_format is None: + response_image_format = image_format w0, h0 = image.size if w is None: w = (h * w0) // h0 @@ -330,13 +369,7 @@ async def get_model_preview(request): h = (w * h0) // w0 exif = image.getexif() - - metadata = None - if len(image.info) > 0: - metadata = PngInfo() - for (key, value) in image.info.items(): - value_str = str(PIL_cast_serializable(value)) # not sure if this is correct (sometimes includes exif) - metadata.add_text(key, value_str) + metadata = get_image_info(image) ratio_original = w0 / h0 ratio_thumbnail = w / h @@ -354,8 +387,10 @@ async def get_model_preview(request): image.thumbnail((w, h)) + if not image_format_is_equal(image_format, response_image_format) and (response_image_format == 'JPEG' or response_image_format == 'JPG'): + image = image.convert('RGB') image_bytes = io.BytesIO() - image.save(image_bytes, format=format, exif=exif, pnginfo=metadata) + image.save(image_bytes, format=response_image_format, exif=exif, pnginfo=metadata, quality=quality) image_data = image_bytes.getvalue() return web.Response( @@ -363,7 +398,7 @@ async def get_model_preview(request): "Content-Disposition": f"inline; filename={file_name}", }, body=image_data, - content_type="image/" + image_type, + content_type="image/" + response_image_format.lower(), ) diff --git a/web/model-manager.js b/web/model-manager.js index b6386c3..f74f514 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -114,9 +114,11 @@ class SearchPath { * @param {string | undefined} [dateImageModified=undefined] * @param {string | undefined} [width=undefined] * @param {string | undefined} [height=undefined] + * @param {string | undefined} [imageFormat=undefined] + * @param {string | undefined} [quality=undefined] * @returns {string} */ -function imageUri(imageSearchPath = undefined, dateImageModified = undefined, width = undefined, height = undefined) { +function imageUri(imageSearchPath = undefined, dateImageModified = undefined, width = undefined, height = undefined, imageFormat = undefined, quality = undefined) { const path = imageSearchPath ?? "no-preview"; const date = dateImageModified; let uri = `/model-manager/preview/get?uri=${path}`; @@ -129,11 +131,19 @@ function imageUri(imageSearchPath = undefined, dateImageModified = undefined, wi if (date !== undefined && date !== null) { uri += `&v=${date}`; } + if (imageFormat !== undefined && imageFormat !== null) { + uri += `&image-format=${imageFormat}`; + } + if (quality !== undefined && quality !== null) { + uri += `&quality=${quality}`; + } return uri; } const PREVIEW_NONE_URI = imageUri(); const PREVIEW_THUMBNAIL_WIDTH = 320; const PREVIEW_THUMBNAIL_HEIGHT = 480; +const PREVIEW_THUMBNAIL_FORMAT = "JPEG"; +const PREVIEW_THUMBNAIL_QUALITY = undefined; /** * @param {(...args) => void} callback @@ -1701,6 +1711,8 @@ class ModelGrid { previewInfo?.dateModified, PREVIEW_THUMBNAIL_WIDTH, PREVIEW_THUMBNAIL_HEIGHT, + PREVIEW_THUMBNAIL_FORMAT, + PREVIEW_THUMBNAIL_QUALITY, ), draggable: false, }),