Added scan to download previews from model info files.

- Fixed bug where scan button was not getting reset.
- Attempt to get full size image preview. (May not get original image.)
This commit is contained in:
Christian Bastian
2024-09-23 17:55:27 -04:00
parent 182c515a6e
commit 75f922bea2
2 changed files with 164 additions and 56 deletions

View File

@@ -271,6 +271,45 @@ def hash_file(path, buffer_size=1024*1024):
class Civitai:
IMAGE_URL_SUBDIRECTORY_PREFIX = "https://civitai.com/images/"
IMAGE_URL_DOMAIN_PREFIX = "'https://image.civitai.com/"
@staticmethod
def image_subdirectory_url_to_image_url(image_url):
url_suffix = image_url[len(Civitai.IMAGE_URL_SUBDIRECTORY_PREFIX):]
image_id = re.search(r"^\d+", url_suffix).group(0)
image_id = str(int(image_id))
image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}"
def_headers = get_def_headers(image_info_url)
response = requests.get(
url=image_info_url,
stream=False,
verify=False,
headers=def_headers,
proxies=None,
allow_redirects=False,
)
if response.ok:
#content_type = response.headers.get("Content-Type")
info = response.json()
items = info["items"]
if len(items) == 0:
raise RuntimeError("Civitai /api/v1/images returned 0 items!")
return items[0]["url"]
else:
raise RuntimeError("Bad response from api/v1/images!")
@staticmethod
def image_domain_url_full_size(url, width = None):
result = re.search("/width=(\d+)", url)
if width is None:
i0 = result.span()[0]
i1 = result.span()[1]
return url[0:i0] + url[i1:]
else:
w = int(result.group(1))
return url.replace(str(w), str(width))
@staticmethod
def search_by_hash(sha256_hash):
url_api_hash = r"https://civitai.com/api/v1/model-versions/by-hash/" + sha256_hash
@@ -301,11 +340,17 @@ class Civitai:
return url
@staticmethod
def get_preview_urls(model_version_info):
def get_preview_urls(model_version_info, full_size=False):
images = model_version_info.get("images", None)
if images is None:
return []
return [image_info["url"] for image_info in images]
preview_urls = []
for image_info in images:
url = image_info["url"]
if full_size:
url = Civitai.image_domain_url_full_size(url, image_info.get("width", None))
preview_urls.append(url)
return preview_urls
@staticmethod
def search_notes(model_version_info):
@@ -432,10 +477,10 @@ class ModelInfo:
return ""
@staticmethod
def get_web_preview_urls(model_info):
def get_web_preview_urls(model_info, full_size=False):
if len(model_info) == 0:
return []
preview_urls = Civitai.get_preview_urls(model_info)
preview_urls = Civitai.get_preview_urls(model_info, full_size)
if len(preview_urls) > 0:
return preview_urls
# TODO: support other websites
@@ -652,42 +697,16 @@ async def get_image_extensions(request):
return web.json_response(image_extensions)
def download_model_preview(formdata):
path = formdata.get("path", None)
if type(path) is not str:
def download_model_preview(path, image, overwrite):
if not os.path.isfile(path):
raise ValueError("Invalid path!")
path, model_type = search_path_to_system_path(path)
model_type_extensions = folder_paths_get_supported_pt_extensions(model_type)
path_without_extension, _ = split_valid_ext(path, model_type_extensions)
path_without_extension = os.path.splitext(path)[0]
overwrite = formdata.get("overwrite", "true").lower()
overwrite = True if overwrite == "true" else False
image = formdata.get("image", None)
if type(image) is str:
civitai_image_url = "https://civitai.com/images/"
if image.startswith(civitai_image_url):
image_id = re.search(r"^\d+", image[len(civitai_image_url):]).group(0)
image_id = str(int(image_id))
image_info_url = f"https://civitai.com/api/v1/images?imageId={image_id}"
def_headers = get_def_headers(image_info_url)
response = requests.get(
url=image_info_url,
stream=False,
verify=False,
headers=def_headers,
proxies=None,
allow_redirects=False,
)
if response.ok:
content_type = response.headers.get("Content-Type")
info = response.json()
items = info["items"]
if len(items) == 0:
raise RuntimeError("Civitai /api/v1/images returned 0 items!")
image = items[0]["url"]
else:
raise RuntimeError("Bad response from api/v1/images!")
if image.startswith(Civitai.IMAGE_URL_SUBDIRECTORY_PREFIX):
image = Civitai.image_subdirectory_url_to_image_url(image)
if image.startswith(Civitai.IMAGE_URL_DOMAIN_PREFIX):
image = Civitai.image_domain_url_full_size(image)
_, image_extension = split_valid_ext(image, image_extensions)
if image_extension == "":
raise ValueError("Invalid image type!")
@@ -715,17 +734,23 @@ def download_model_preview(formdata):
# detect (and try to fix) wrong file extension
image_format = None
with Image.open(image_path) as image:
image_format = image.format
image_dir_and_name, image_ext = os.path.splitext(image_path)
if not image_format_is_equal(image_format, image_ext):
corrected_image_path = image_dir_and_name + "." + image_format.lower()
if os.path.exists(corrected_image_path) and not overwrite:
print("WARNING: '" + image_path + "' has wrong extension!")
else:
os.rename(image_path, corrected_image_path)
print("Saved file: " + corrected_image_path)
image_path = corrected_image_path
try:
with Image.open(image_path) as image:
image_format = image.format
image_dir_and_name, image_ext = os.path.splitext(image_path)
if not image_format_is_equal(image_format, image_ext):
corrected_image_path = image_dir_and_name + "." + image_format.lower()
if os.path.exists(corrected_image_path) and not overwrite:
print("WARNING: '" + image_path + "' has wrong extension!")
else:
os.rename(image_path, corrected_image_path)
print("Saved file: " + corrected_image_path)
image_path = corrected_image_path
except Image.UnidentifiedImageError as e: #TODO: handle case where "image" is actually video
print("WARNING: '" + image_path + "' image format was unknown!")
os.remove(image_path)
print("Deleted file: " + image_path)
image_path = ""
return image_path # return in-case need corrected path
@@ -733,7 +758,15 @@ def download_model_preview(formdata):
async def set_model_preview(request):
formdata = await request.post()
try:
download_model_preview(formdata)
search_path = formdata.get("path", None)
model_path, model_type = search_path_to_system_path(search_path)
image = formdata.get("image", None)
overwrite = formdata.get("overwrite", "true").lower()
overwrite = True if overwrite == "true" else False
download_model_preview(model_path, image, overwrite)
return web.json_response({ "success": True })
except ValueError as e:
print(e, file=sys.stderr, flush=True)
@@ -1047,6 +1080,48 @@ async def try_scan_download(request):
response["success"] = True
return web.json_response(response)
@server.PromptServer.instance.routes.post("/model-manager/preview/scan")
async def try_scan_download_previews(request):
refresh = request.query.get("refresh", None) is not None
response = {
"success": False,
"count": 0,
}
model_paths = folder_paths_folder_names_and_paths(refresh)
for _, (model_dirs, model_extension_whitelist) in model_paths.items():
for root_dir in model_dirs:
for root, dirs, files in os.walk(root_dir):
for file in files:
file_name, file_extension = os.path.splitext(file)
if file_extension not in model_extension_whitelist:
continue
model_file_path = root + os.path.sep + file
model_file_head = os.path.splitext(model_file_path)[0]
preview_exists = False
for preview_extension in preview_extensions:
preview_path = model_file_head + preview_extension
if os.path.isfile(preview_path):
preview_exists = True
break
if preview_exists:
continue
model_info = ModelInfo.try_load_cached(model_file_path) # NOTE: model info must already be downloaded
web_previews = ModelInfo.get_web_preview_urls(model_info, True)
if len(web_previews) == 0:
continue
saved_image_path = download_model_preview(
model_file_path,
image=web_previews[0],
overwrite=False,
)
if os.path.isfile(saved_image_path):
response["count"] += 1
response["success"] = True
return web.json_response(response)
def download_file(url, filename, overwrite):
if not overwrite and os.path.isfile(filename):
@@ -1272,7 +1347,7 @@ async def get_model_metadata(request):
tags.sort(key=lambda x: x[1], reverse=True)
model_info = ModelInfo.try_load_cached(abs_path)
web_previews = ModelInfo.get_web_preview_urls(model_info)
web_previews = ModelInfo.get_web_preview_urls(model_info, True)
result["success"] = True
result["info"] = data
@@ -1398,11 +1473,11 @@ async def download_model(request):
image = formdata.get("image")
if image is not None and image != "":
try:
download_model_preview({
"path": model_path + os.sep + name,
"image": image,
"overwrite": formdata.get("overwrite"),
})
download_model_preview(
file_name,
image,
formdata.get("overwrite"),
)
except Exception as e:
print(e, file=sys.stderr, flush=True)
result["alert"] = "Failed to download preview!\n\n" + str(e)

View File

@@ -4836,7 +4836,8 @@ class SettingsView {
}).catch((err) => {
return { success: false };
});
const successMessage = data['success'] ? "Scan Finished!" : "Scan Failed!";
const success = data['success'];
const successMessage = success ? "Scan Finished!" : "Scan Failed!";
const infoCount = data['infoCount'];
const notesCount = data['notesCount'];
const urlCount = data['urlCount'];
@@ -4849,6 +4850,37 @@ class SettingsView {
},
}).element;
const scanDownloadPreviewsButton = new ComfyButton({
content: 'Download Missing Previews',
tooltip: 'Downloads missing model previews from model info.\nRun model info scan first!',
action: async (e) => {
const confirmation = window.confirm(
'WARNING: This may take a while and generate MANY server requests!\nUSE AT YOUR OWN RISK!',
);
if (!confirmation) {
return;
}
const [button, icon, span] = comfyButtonDisambiguate(e.target);
button.disabled = true;
const data = await comfyRequest('/model-manager/preview/scan', {
method: 'POST',
body: JSON.stringify({}),
}).catch((err) => {
return { success: false };
});
const success = data['success'];
const successMessage = success ? "Scan Finished!" : "Scan Failed!";
const count = data['count'];
window.alert(`${successMessage}\nPreviews Downloaded: ${count}`);
comfyButtonAlert(e.target, success);
if (count > 0) {
await this.reload(true);
}
button.disabled = false;
},
}).element;
$el(
'div.model-manager-settings',
{
@@ -5011,6 +5043,7 @@ class SettingsView {
$el('h2', ['Scan Files']),
$el('div', [correctPreviewsButton]),
$el('div', [scanDownloadModelInfosButton]),
$el('div', [scanDownloadPreviewsButton]),
$el('h2', ['Random Tag Generator']),
$select({
$: (el) => (settings['tag-generator-sampler-method'] = el),