Avoid overwriting preexisting preview images unless user selected "overwrite".

- Fixed invalid exceptions.
This commit is contained in:
Christian Bastian
2024-07-24 15:53:55 -04:00
parent 4f5ba8deec
commit 3ca0f500b2

View File

@@ -433,7 +433,7 @@ async def get_image_extensions(request):
def download_model_preview(formdata): def download_model_preview(formdata):
path = formdata.get("path", None) path = formdata.get("path", None)
if type(path) is not str: if type(path) is not str:
raise ("Invalid path!") raise ValueError("Invalid path!")
path, model_type = search_path_to_system_path(path) path, model_type = search_path_to_system_path(path)
model_type_extensions = folder_paths_get_supported_pt_extensions(model_type) model_type_extensions = folder_paths_get_supported_pt_extensions(model_type)
path_without_extension, _ = split_valid_ext(path, model_type_extensions) path_without_extension, _ = split_valid_ext(path, model_type_extensions)
@@ -474,32 +474,36 @@ def download_model_preview(formdata):
else: else:
content_type = image.content_type content_type = image.content_type
if not content_type.startswith("image/"): if not content_type.startswith("image/"):
raise ("Invalid content type!") raise RuntimeError("Invalid content type!")
image_extension = "." + content_type[len("image/"):] image_extension = "." + content_type[len("image/"):]
if image_extension not in image_extensions: if image_extension not in image_extensions:
raise ("Invalid extension!") raise RuntimeError("Invalid extension!")
image_path = path_without_extension + image_extension image_path = path_without_extension + image_extension
if not overwrite and os.path.isfile(image_path): if not overwrite and os.path.isfile(image_path):
raise ("Image already exists!") raise RuntimeError("Image already exists!")
file: io.IOBase = image.file file: io.IOBase = image.file
image_data = file.read() image_data = file.read()
with open(image_path, "wb") as f: with open(image_path, "wb") as f:
f.write(image_data) f.write(image_data)
print("Saved file: " + image_path) print("Saved file: " + image_path)
delete_same_name_files(path_without_extension, preview_extensions, image_extension) if overwrite:
delete_same_name_files(path_without_extension, preview_extensions, image_extension)
# detect and fix wrong file extension # detect (and try to fix) wrong file extension
image_format = None image_format = None
with Image.open(image_path) as image: with Image.open(image_path) as image:
image_format = image.format image_format = image.format
image_dir_and_name, image_ext = os.path.splitext(image_path) image_dir_and_name, image_ext = os.path.splitext(image_path)
if not image_format_is_equal(image_format, image_ext): if not image_format_is_equal(image_format, image_ext):
corrected_image_path = image_dir_and_name + "." + image_format.lower() corrected_image_path = image_dir_and_name + "." + image_format.lower()
os.rename(image_path, corrected_image_path) if os.path.exists(corrected_image_path) and not overwrite:
print("Saved file: " + corrected_image_path) print("WARNING: '" + image_path + "' has wrong extension!")
image_path = corrected_image_path else:
os.rename(image_path, corrected_image_path)
print("Saved file: " + corrected_image_path)
image_path = corrected_image_path
return image_path # return in-case need corrected path return image_path # return in-case need corrected path