diff --git a/__init__.py b/__init__.py index 47513aa..b09db8d 100644 --- a/__init__.py +++ b/__init__.py @@ -433,7 +433,7 @@ async def get_image_extensions(request): def download_model_preview(formdata): path = formdata.get("path", None) if type(path) is not str: - raise ("Invalid path!") + raise ValueError("Invalid path!") path, model_type = search_path_to_system_path(path) model_type_extensions = folder_paths_get_supported_pt_extensions(model_type) path_without_extension, _ = split_valid_ext(path, model_type_extensions) @@ -474,32 +474,36 @@ def download_model_preview(formdata): else: content_type = image.content_type if not content_type.startswith("image/"): - raise ("Invalid content type!") + raise RuntimeError("Invalid content type!") image_extension = "." + content_type[len("image/"):] if image_extension not in image_extensions: - raise ("Invalid extension!") + raise RuntimeError("Invalid extension!") image_path = path_without_extension + image_extension if not overwrite and os.path.isfile(image_path): - raise ("Image already exists!") + raise RuntimeError("Image already exists!") file: io.IOBase = image.file image_data = file.read() with open(image_path, "wb") as f: f.write(image_data) print("Saved file: " + image_path) - delete_same_name_files(path_without_extension, preview_extensions, image_extension) + if overwrite: + delete_same_name_files(path_without_extension, preview_extensions, image_extension) - # detect and fix wrong file extension + # detect (and try to fix) wrong file extension image_format = None with Image.open(image_path) as image: image_format = image.format image_dir_and_name, image_ext = os.path.splitext(image_path) if not image_format_is_equal(image_format, image_ext): corrected_image_path = image_dir_and_name + "." + image_format.lower() - os.rename(image_path, corrected_image_path) - print("Saved file: " + corrected_image_path) - image_path = corrected_image_path + if os.path.exists(corrected_image_path) and not overwrite: + print("WARNING: '" + image_path + "' has wrong extension!") + else: + os.rename(image_path, corrected_image_path) + print("Saved file: " + corrected_image_path) + image_path = corrected_image_path return image_path # return in-case need corrected path