From 75f922bea25dbd813e80136c2a76b06c8335b4fb Mon Sep 17 00:00:00 2001 From: Christian Bastian <80225746+cdb-boop@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:55:27 -0400 Subject: [PATCH] Added scan to download previews from model info files. - Fixed bug where scan button was not getting reset. - Attempt to get full size image preview. (May not get original image.) --- __init__.py | 185 ++++++++++++++++++++++++++++++------------- web/model-manager.js | 35 +++++++- 2 files changed, 164 insertions(+), 56 deletions(-) diff --git a/__init__.py b/__init__.py index c97a92c..5263ba8 100644 --- a/__init__.py +++ b/__init__.py @@ -271,6 +271,45 @@ def hash_file(path, buffer_size=1024*1024): class Civitai: + IMAGE_URL_SUBDIRECTORY_PREFIX = "https://civitai.com/images/" + IMAGE_URL_DOMAIN_PREFIX = "'https://image.civitai.com/" + + @staticmethod + def image_subdirectory_url_to_image_url(image_url): + url_suffix = image_url[len(Civitai.IMAGE_URL_SUBDIRECTORY_PREFIX):] + image_id = re.search(r"^\d+", url_suffix).group(0) + image_id = str(int(image_id)) + image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}" + def_headers = get_def_headers(image_info_url) + response = requests.get( + url=image_info_url, + stream=False, + verify=False, + headers=def_headers, + proxies=None, + allow_redirects=False, + ) + if response.ok: + #content_type = response.headers.get("Content-Type") + info = response.json() + items = info["items"] + if len(items) == 0: + raise RuntimeError("Civitai /api/v1/images returned 0 items!") + return items[0]["url"] + else: + raise RuntimeError("Bad response from api/v1/images!") + + @staticmethod + def image_domain_url_full_size(url, width = None): + result = re.search("/width=(\d+)", url) + if width is None: + i0 = result.span()[0] + i1 = result.span()[1] + return url[0:i0] + url[i1:] + else: + w = int(result.group(1)) + return url.replace(str(w), str(width)) + @staticmethod def search_by_hash(sha256_hash): url_api_hash = r"https://civitai.com/api/v1/model-versions/by-hash/" + sha256_hash @@ -301,11 +340,17 @@ class Civitai: return url @staticmethod - def get_preview_urls(model_version_info): + def get_preview_urls(model_version_info, full_size=False): images = model_version_info.get("images", None) if images is None: return [] - return [image_info["url"] for image_info in images] + preview_urls = [] + for image_info in images: + url = image_info["url"] + if full_size: + url = Civitai.image_domain_url_full_size(url, image_info.get("width", None)) + preview_urls.append(url) + return preview_urls @staticmethod def search_notes(model_version_info): @@ -432,10 +477,10 @@ class ModelInfo: return "" @staticmethod - def get_web_preview_urls(model_info): + def get_web_preview_urls(model_info, full_size=False): if len(model_info) == 0: return [] - preview_urls = Civitai.get_preview_urls(model_info) + preview_urls = Civitai.get_preview_urls(model_info, full_size) if len(preview_urls) > 0: return preview_urls # TODO: support other websites @@ -652,42 +697,16 @@ async def get_image_extensions(request): return web.json_response(image_extensions) -def download_model_preview(formdata): - path = formdata.get("path", None) - if type(path) is not str: +def download_model_preview(path, image, overwrite): + if not os.path.isfile(path): raise ValueError("Invalid path!") - path, model_type = search_path_to_system_path(path) - model_type_extensions = folder_paths_get_supported_pt_extensions(model_type) - path_without_extension, _ = split_valid_ext(path, model_type_extensions) + path_without_extension = os.path.splitext(path)[0] - overwrite = formdata.get("overwrite", "true").lower() - overwrite = True if overwrite == "true" else False - - image = formdata.get("image", None) if type(image) is str: - civitai_image_url = "https://civitai.com/images/" - if image.startswith(civitai_image_url): - image_id = re.search(r"^\d+", image[len(civitai_image_url):]).group(0) - image_id = str(int(image_id)) - image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}" - def_headers = get_def_headers(image_info_url) - response = requests.get( - url=image_info_url, - stream=False, - verify=False, - headers=def_headers, - proxies=None, - allow_redirects=False, - ) - if response.ok: - content_type = response.headers.get("Content-Type") - info = response.json() - items = info["items"] - if len(items) == 0: - raise RuntimeError("Civitai /api/v1/images returned 0 items!") - image = items[0]["url"] - else: - raise RuntimeError("Bad response from api/v1/images!") + if image.startswith(Civitai.IMAGE_URL_SUBDIRECTORY_PREFIX): + image = Civitai.image_subdirectory_url_to_image_url(image) + if image.startswith(Civitai.IMAGE_URL_DOMAIN_PREFIX): + image = Civitai.image_domain_url_full_size(image) _, image_extension = split_valid_ext(image, image_extensions) if image_extension == "": raise ValueError("Invalid image type!") @@ -715,17 +734,23 @@ def download_model_preview(formdata): # detect (and try to fix) wrong file extension image_format = None - with Image.open(image_path) as image: - image_format = image.format - image_dir_and_name, image_ext = os.path.splitext(image_path) - if not image_format_is_equal(image_format, image_ext): - corrected_image_path = image_dir_and_name + "." + image_format.lower() - if os.path.exists(corrected_image_path) and not overwrite: - print("WARNING: '" + image_path + "' has wrong extension!") - else: - os.rename(image_path, corrected_image_path) - print("Saved file: " + corrected_image_path) - image_path = corrected_image_path + try: + with Image.open(image_path) as image: + image_format = image.format + image_dir_and_name, image_ext = os.path.splitext(image_path) + if not image_format_is_equal(image_format, image_ext): + corrected_image_path = image_dir_and_name + "." + image_format.lower() + if os.path.exists(corrected_image_path) and not overwrite: + print("WARNING: '" + image_path + "' has wrong extension!") + else: + os.rename(image_path, corrected_image_path) + print("Saved file: " + corrected_image_path) + image_path = corrected_image_path + except Image.UnidentifiedImageError as e: #TODO: handle case where "image" is actually video + print("WARNING: '" + image_path + "' image format was unknown!") + os.remove(image_path) + print("Deleted file: " + image_path) + image_path = "" return image_path # return in-case need corrected path @@ -733,7 +758,15 @@ def download_model_preview(formdata): async def set_model_preview(request): formdata = await request.post() try: - download_model_preview(formdata) + search_path = formdata.get("path", None) + model_path, model_type = search_path_to_system_path(search_path) + + image = formdata.get("image", None) + + overwrite = formdata.get("overwrite", "true").lower() + overwrite = True if overwrite == "true" else False + + download_model_preview(model_path, image, overwrite) return web.json_response({ "success": True }) except ValueError as e: print(e, file=sys.stderr, flush=True) @@ -1047,6 +1080,48 @@ async def try_scan_download(request): response["success"] = True return web.json_response(response) +@server.PromptServer.instance.routes.post("/model-manager/preview/scan") +async def try_scan_download_previews(request): + refresh = request.query.get("refresh", None) is not None + response = { + "success": False, + "count": 0, + } + model_paths = folder_paths_folder_names_and_paths(refresh) + for _, (model_dirs, model_extension_whitelist) in model_paths.items(): + for root_dir in model_dirs: + for root, dirs, files in os.walk(root_dir): + for file in files: + file_name, file_extension = os.path.splitext(file) + if file_extension not in model_extension_whitelist: + continue + model_file_path = root + os.path.sep + file + model_file_head = os.path.splitext(model_file_path)[0] + + preview_exists = False + for preview_extension in preview_extensions: + preview_path = model_file_head + preview_extension + if os.path.isfile(preview_path): + preview_exists = True + break + if preview_exists: + continue + + model_info = ModelInfo.try_load_cached(model_file_path) # NOTE: model info must already be downloaded + web_previews = ModelInfo.get_web_preview_urls(model_info, True) + if len(web_previews) == 0: + continue + saved_image_path = download_model_preview( + model_file_path, + image=web_previews[0], + overwrite=False, + ) + if os.path.isfile(saved_image_path): + response["count"] += 1 + + response["success"] = True + return web.json_response(response) + def download_file(url, filename, overwrite): if not overwrite and os.path.isfile(filename): @@ -1272,7 +1347,7 @@ async def get_model_metadata(request): tags.sort(key=lambda x: x[1], reverse=True) model_info = ModelInfo.try_load_cached(abs_path) - web_previews = ModelInfo.get_web_preview_urls(model_info) + web_previews = ModelInfo.get_web_preview_urls(model_info, True) result["success"] = True result["info"] = data @@ -1398,11 +1473,11 @@ async def download_model(request): image = formdata.get("image") if image is not None and image != "": try: - download_model_preview({ - "path": model_path + os.sep + name, - "image": image, - "overwrite": formdata.get("overwrite"), - }) + download_model_preview( + file_name, + image, + formdata.get("overwrite"), + ) except Exception as e: print(e, file=sys.stderr, flush=True) result["alert"] = "Failed to download preview!\n\n" + str(e) diff --git a/web/model-manager.js b/web/model-manager.js index ed07c3b..5659dcc 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -4836,7 +4836,8 @@ class SettingsView { }).catch((err) => { return { success: false }; }); - const successMessage = data['success'] ? "Scan Finished!" : "Scan Failed!"; + const success = data['success']; + const successMessage = success ? "Scan Finished!" : "Scan Failed!"; const infoCount = data['infoCount']; const notesCount = data['notesCount']; const urlCount = data['urlCount']; @@ -4849,6 +4850,37 @@ class SettingsView { }, }).element; + const scanDownloadPreviewsButton = new ComfyButton({ + content: 'Download Missing Previews', + tooltip: 'Downloads missing model previews from model info.\nRun model info scan first!', + action: async (e) => { + const confirmation = window.confirm( + 'WARNING: This may take a while and generate MANY server requests!\nUSE AT YOUR OWN RISK!', + ); + if (!confirmation) { + return; + } + + const [button, icon, span] = comfyButtonDisambiguate(e.target); + button.disabled = true; + const data = await comfyRequest('/model-manager/preview/scan', { + method: 'POST', + body: JSON.stringify({}), + }).catch((err) => { + return { success: false }; + }); + const success = data['success']; + const successMessage = success ? "Scan Finished!" : "Scan Failed!"; + const count = data['count']; + window.alert(`${successMessage}\nPreviews Downloaded: ${count}`); + comfyButtonAlert(e.target, success); + if (count > 0) { + await this.reload(true); + } + button.disabled = false; + }, + }).element; + $el( 'div.model-manager-settings', { @@ -5011,6 +5043,7 @@ class SettingsView { $el('h2', ['Scan Files']), $el('div', [correctPreviewsButton]), $el('div', [scanDownloadModelInfosButton]), + $el('div', [scanDownloadPreviewsButton]), $el('h2', ['Random Tag Generator']), $select({ $: (el) => (settings['tag-generator-sampler-method'] = el),