diff --git a/__init__.py b/__init__.py index 5746ea0..417ec92 100644 --- a/__init__.py +++ b/__init__.py @@ -244,56 +244,63 @@ def get_def_headers(url=""): return def_headers -def civitai_get_model_version_info_by_hash(sha256_hash): - url_api_hash = r"https://civitai.com/api/v1/model-versions/by-hash/" + sha256_hash - hash_response = requests.get(url_api_hash) - if hash_response.status_code != 200: - return {} - return hash_response.json() - -def civitai_get_model_info_by_model_id(model_id): - url_api_model = r"https://civitai.com/api/v1/models/" + str(model_id) - model_response = requests.get(url_api_model) - if model_response.status_code != 200: - return {} - return model_response.json() +def hash_file(path, buffer_size=1024*1024): + sha256 = hashlib.sha256() + with open(path, 'rb') as f: + while True: + data = f.read(buffer_size) + if not data: break + sha256.update(data) + return sha256.hexdigest() -def search_web_for_model_info(sha256_hash): - model_info = civitai_get_model_version_info_by_hash(sha256_hash) - if len(model_info) > 0: return model_info +class Civitai: + @staticmethod + def search_by_hash(sha256_hash): + url_api_hash = r"https://civitai.com/api/v1/model-versions/by-hash/" + sha256_hash + hash_response = requests.get(url_api_hash) + if hash_response.status_code != 200: + return {} + return hash_response.json() # model version info - # TODO: search other websites + @staticmethod + def search_by_model_id(model_id): + url_api_model = r"https://civitai.com/api/v1/models/" + str(model_id) + model_response = requests.get(url_api_model) + if model_response.status_code != 200: + return {} + return model_response.json() # model group info - return {} + @staticmethod + def get_model_url(model_version_info): + if len(model_version_info) == 0: return "" + model_id = model_version_info.get("modelId") + if model_id is None: + # there can be incomplete model info, so don't throw just in case + return "" + url = f"https://civitai.com/models/{model_id}" + version_id = model_version_info.get("id") + if version_id is not None: + url += f"?modelVersionId={version_id}" + return url + @staticmethod + def search_notes(model_version_info): + model_id = model_version_info.get("modelId") + model_version_id = model_version_info.get("id") -def search_web_for_model_url(sha256_hash): - model_info = civitai_get_model_version_info_by_hash(sha256_hash) - if len(model_info) > 0: - model_id = model_info["modelId"] - version_id = model_info["id"] - return f"https://civitai.com/models/{model_id}?modelVersionId={version_id}" + assert(model_id is not None) + assert(model_version_id is not None) - # TODO: search other websites - - return "" - - -def search_web_for_model_notes(sha256_hash): - model_info = civitai_get_model_version_info_by_hash(sha256_hash) - model_info = civitai_get_model_info_by_model_id(model_info["modelId"]) - if len(model_info) > 0: - model_description = model_info.get("description", "") model_version_description = "" model_trigger_words = [] + model_info = Civitai.search_by_model_id(model_id) + model_description = model_info.get("description", "") for model_version in model_info["modelVersions"]: - for files in model_version["files"]: - if files["hashes"]["SHA256"].lower() == sha256_hash.lower(): - model_version_description = model_version.get("description", "") - model_trigger_words = model_version.get("trainedWords", "") - break - if model_version_description != "": break + if model_version["id"] == model_version_id: + model_version_description = model_version.get("description", "") + model_trigger_words = model_version.get("trainedWords", []) + break notes = "" if len(model_trigger_words) > 0: @@ -313,24 +320,76 @@ def search_web_for_model_notes(sha256_hash): notes += markdownify.markdownify(model_version_description) if model_description != "": if len(notes) > 0: notes += "\n\n" - notes += "# " + model_info.get("name", str(model_info["id"])) + "\n\n" + notes += "# " + model_info.get("name", str(model_id)) + "\n\n" notes += markdownify.markdownify(model_description) - notes = notes.strip() + return notes.strip() + + +class ModelInfo: + @staticmethod + def search_by_hash(sha256_hash): + model_info = Civitai.search_by_hash(sha256_hash) + if len(model_info) > 0: return model_info + # TODO: search other websites + return {} + + @staticmethod + def try_load_cached(model_path): + model_info_path = os.path.splitext(model_path)[0] + model_info_extension + if os.path.isfile(model_info_path): + with open(model_info_path, "r", encoding="utf-8") as f: + model_info = json.load(f) + return model_info + return {} + + @staticmethod + def get_hash(model_info): + model_info = Civitai.get_hash(model_info) + if len(model_info) > 0: return model_info + # TODO: search other websites + return {} + + @staticmethod + def search_info(model_path, cache=True, use_cached=True): + model_info = ModelInfo.try_load_cached(model_path) + if use_cached and len(model_info) > 0: + return model_info + + sha256_hash = hash_file(model_path) + model_info = ModelInfo.search_by_hash(sha256_hash) + if cache and len(model_info) > 0: + model_info_path = os.path.splitext(model_path)[0] + model_info_extension + with open(model_info_path, "w", encoding="utf-8") as f: + json.dump(model_info, f, indent=4) + print("Saved file: " + model_info_path) + + return model_info + + @staticmethod + def get_url(model_info): + if len(model_info) == 0: return "" + model_url = Civitai.get_model_url(model_info) + if model_url != "": return model_url + + # TODO: huggingface has / formats + + # TODO: support other websites + return "" + + @staticmethod + def search_notes(model_path): + notes = "" + + model_info = ModelInfo.search_info(model_path, cache=True, use_cached=True) # assume cached is correct; re-download elsewhere + if len(model_info) > 0: + notes = Civitai.search_notes(model_info) + + # TODO: support other websites return notes - # TODO: search other websites + # TODO: search other websites - return "" - - -def hash_file(path, buffer_size=1024*1024): - sha256 = hashlib.sha256() - with open(path, 'rb') as f: - while True: - data = f.read(buffer_size) - if not data: break - sha256.update(data) - return sha256.hexdigest() + return "" @server.PromptServer.instance.routes.get("/model-manager/timestamp") @@ -979,7 +1038,7 @@ def bytes_to_size(total_bytes): @server.PromptServer.instance.routes.get("/model-manager/model/info/{path}") -async def get_model_info(request): +async def get_model_metadata(request): result = { "success": False } model_path = request.match_info["path"] @@ -993,16 +1052,16 @@ async def get_model_info(request): result["alert"] = "Invalid model path!" return web.json_response(result) - info = {} + data = {} comfyui_directory, name = os.path.split(model_path) - info["File Name"] = name - info["File Directory"] = comfyui_directory - info["File Size"] = bytes_to_size(os.path.getsize(abs_path)) + data["File Name"] = name + data["File Directory"] = comfyui_directory + data["File Size"] = bytes_to_size(os.path.getsize(abs_path)) stats = pathlib.Path(abs_path).stat() date_format = "%Y-%m-%d %H:%M:%S" date_modified = datetime.fromtimestamp(stats.st_mtime).strftime(date_format) - #info["Date Modified"] = date_modified - #info["Date Created"] = datetime.fromtimestamp(stats.st_ctime).strftime(date_format) + #data["Date Modified"] = date_modified + #data["Date Created"] = datetime.fromtimestamp(stats.st_ctime).strftime(date_format) model_extensions = folder_paths_get_supported_pt_extensions(model_type) abs_name , _ = split_valid_ext(abs_path, model_extensions) @@ -1012,7 +1071,7 @@ async def get_model_info(request): if os.path.isfile(maybe_preview): preview_path, _ = split_valid_ext(model_path, model_extensions) preview_modified = pathlib.Path(maybe_preview).stat().st_mtime_ns - info["Preview"] = { + data["Preview"] = { "path": preview_path + extension, "dateModified": str(preview_modified), } @@ -1021,27 +1080,27 @@ async def get_model_info(request): header = get_safetensor_header(abs_path) metadata = header.get("__metadata__", None) - if metadata is not None and info.get("Preview", None) is None: + if metadata is not None and data.get("Preview", None) is None: thumbnail = metadata.get("modelspec.thumbnail") if thumbnail is not None: i0 = thumbnail.find("/") + 1 i1 = thumbnail.find(";", i0) thumbnail_extension = "." + thumbnail[i0:i1] if thumbnail_extension in image_extensions: - info["Preview"] = { + data["Preview"] = { "path": request.query["path"] + thumbnail_extension, "dateModified": date_modified, } if metadata is not None: - info["Base Training Model"] = metadata.get("ss_sd_model_name", "") - info["Base Model Version"] = metadata.get("ss_base_model_version", "") - info["Network Dimension"] = metadata.get("ss_network_dim", "") - info["Network Alpha"] = metadata.get("ss_network_alpha", "") + data["Base Training Model"] = metadata.get("ss_sd_model_name", "") + data["Base Model Version"] = metadata.get("ss_base_model_version", "") + data["Network Dimension"] = metadata.get("ss_network_dim", "") + data["Network Alpha"] = metadata.get("ss_network_alpha", "") if metadata is not None: training_comment = metadata.get("ss_training_comment", "") - info["Description"] = ( + data["Description"] = ( metadata.get("modelspec.description", "") + "\n\n" + metadata.get("modelspec.usage_hint", "") + @@ -1076,7 +1135,7 @@ async def get_model_info(request): resolutions[str(x) + "x" + str(y)] = count resolutions = list(resolutions.items()) resolutions.sort(key=lambda x: x[1], reverse=True) - info["Bucket Resolutions"] = resolutions + data["Bucket Resolutions"] = resolutions tags = None if metadata is not None: @@ -1091,7 +1150,7 @@ async def get_model_info(request): tags.sort(key=lambda x: x[1], reverse=True) result["success"] = True - result["info"] = info + result["info"] = data if metadata is not None: result["metadata"] = metadata if tags is not None: @@ -1099,8 +1158,9 @@ async def get_model_info(request): result["notes"] = notes return web.json_response(result) + @server.PromptServer.instance.routes.get("/model-manager/model/web-url") -async def get_model_info(request): +async def get_model_web_url(request): result = { "success": False } model_path = request.query.get("path", None) @@ -1114,9 +1174,14 @@ async def get_model_info(request): result["alert"] = "Invalid model path!" return web.json_response(result) - sha256_hash = hash_file(abs_path) - web_url = search_web_for_model_url(sha256_hash) + model_info = ModelInfo.search_info(abs_path) + if len(model_info) == 0: + result["alert"] = "Unable to find model info!" + return web.json_response(result) + web_url = ModelInfo.get_url(model_info) + if web_url != "": + result["success"] = True return web.json_response({ "url": web_url }) @@ -1164,18 +1229,7 @@ async def download_model(request): return web.json_response(result) # download model info - sha256_hash = formdata.get("sha256", None) - if sha256_hash is not None: - model_info = search_web_for_model_info(sha256_hash) - if len(model_info) > 0: - info_path = os.path.splitext(file_name)[0] + ".json" - try: - with open(info_path, "w", encoding="utf-8") as f: - json.dump(model_info, f, indent=4) - print("Saved file: " + info_path) - except ValueError as e: - print(e, file=sys.stderr, flush=True) - result["alert"] = "Failed to save model info!\n\n" + str(e) # TODO: >1 alert? concat? + _ = ModelInfo.search_info(file_name, cache=True) # save json # save image as model preview image = formdata.get("image") @@ -1379,8 +1433,7 @@ async def try_download_notes(request): result["alert"] = "Notes already exist!" return web.json_response(result) - sha256_hash = hash_file(abs_path) - notes = search_web_for_model_notes(sha256_hash) + notes = ModelInfo.search_notes(abs_path) if not notes.isspace() and notes != "": try: with open(notes_path, "w", encoding="utf-8") as f: diff --git a/web/model-manager.js b/web/model-manager.js index d58c4ad..a498cc6 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -2370,9 +2370,9 @@ class ModelInfo { [this.elements.tabButtons, this.elements.tabContents] = GenerateTabGroup([ { name: "Overview", icon: "information-box-outline", tabContent: this.element }, - { name: "Metadata", icon: "file-document-outline", tabContent: $el("div", ["Metadata"]) }, - { name: "Tags", icon: "tag-outline", tabContent: $el("div", ["Tags"]) }, { name: "Notes", icon: "pencil-outline", tabContent: $el("div", ["Notes"]) }, + { name: "Tags", icon: "tag-outline", tabContent: $el("div", ["Tags"]) }, + { name: "Metadata", icon: "file-document-outline", tabContent: $el("div", ["Metadata"]) }, ]); } @@ -2457,7 +2457,7 @@ class ModelInfo { result["info"], result["metadata"], result["tags"], - result["notes"] + result["notes"], ]; }) .catch((err) => { @@ -2642,33 +2642,137 @@ class ModelInfo { infoHtml.append.apply(infoHtml, innerHtml); // TODO: set default value of dropdown and value to model type? - /** @type {HTMLDivElement} */ - const metadataElement = this.elements.tabContents[1]; // TODO: remove magic value - const isMetadata = typeof metadata === 'object' && metadata !== null && Object.keys(metadata).length > 0; - metadataElement.innerHTML = ""; - metadataElement.append.apply(metadataElement, [ - $el("h1", ["Metadata"]), - $el("div", (() => { - const tableRows = []; - if (isMetadata) { - for (const [key, value] of Object.entries(metadata)) { - if (value === undefined || value === null) { - continue; - } - if (value !== "") { - tableRows.push($el("tr", [ - $el("th.model-metadata-key", [key]), - $el("th.model-metadata-value", [value]), - ])); - } - } + // + // NOTES + // + + const saveIcon = "content-save"; + const savingIcon = "cloud-upload-outline"; + + const saveNotesButton = new ComfyButton({ + icon: saveIcon, + tooltip: "Save note", + classList: "comfyui-button icon-button", + action: async (e) => { + const [button, icon, span] = comfyButtonDisambiguate(e.target); + button.disabled = true; + const saved = await this.trySave(false); + comfyButtonAlert(e.target, saved); + button.disabled = false; + }, + }).element; + + const downloadNotesButton = new ComfyButton({ + icon: "earth-arrow-down", + tooltip: "Attempt to download model info from the internet.", + classList: "comfyui-button icon-button", + action: async (e) => { + if (this.#savedNotesValue !== "") { + const overwriteNoteConfirmation = window.confirm("Overwrite note?"); + if (!overwriteNoteConfirmation) { + comfyButtonAlert(e.target, false, "mdi-check-bold", "mdi-close-thick"); + return; } - return $el("table.model-metadata", tableRows); - })(), - ), - ]); - const metadataButton = this.elements.tabButtons[1]; // TODO: remove magic value - metadataButton.style.display = isMetadata ? "" : "none"; + } + + const [button, icon, span] = comfyButtonDisambiguate(e.target); + button.disabled = true; + const [success, downloadedNotesValue] = await comfyRequest( + `/model-manager/notes/download?path=${path}&overwrite=True`, + { + method: "POST", + body: {}, + } + ).then((data) => { + const success = data["success"]; + const message = data["alert"]; + if (message !== undefined) { + window.alert(message); + } + return [success, data["notes"]]; + }).catch((err) => { + return [false, ""]; + }); + if (success) { + this.#savedNotesValue = downloadedNotesValue; + this.elements.notes.value = downloadedNotesValue; + } + comfyButtonAlert(e.target, success, "mdi-check-bold", "mdi-close-thick"); + button.disabled = false; + }, + }).element; + + const saveDebounce = debounce(async() => { + const saveIconClass = "mdi-" + saveIcon; + const savingIconClass = "mdi-" + savingIcon; + const iconElement = saveNotesButton.getElementsByTagName("i")[0]; + iconElement.classList.remove(saveIconClass); + iconElement.classList.add(savingIconClass); + const saved = await this.trySave(false); + iconElement.classList.remove(savingIconClass); + iconElement.classList.add(saveIconClass); + }, 1000); + + /** @type {HTMLDivElement} */ + const notesElement = this.elements.tabContents[1]; // TODO: remove magic value + notesElement.innerHTML = ""; + notesElement.append.apply(notesElement, + (() => { + const notes = $el("textarea.comfy-multiline-input", { + name: "model notes", + value: noteText, + oninput: (e) => { + if (this.#settingsElements["model-info-autosave-notes"].checked) { + saveDebounce(); + } + }, + }); + + if (navigator.userAgent.includes("Mac")) { + new KeyComboListener( + ["MetaLeft", "KeyS"], + saveDebounce, + notes, + ); + new KeyComboListener( + ["MetaRight", "KeyS"], + saveDebounce, + notes, + ); + } + else { + new KeyComboListener( + ["ControlLeft", "KeyS"], + saveDebounce, + notes, + ); + new KeyComboListener( + ["ControlRight", "KeyS"], + saveDebounce, + notes, + ); + } + + this.elements.notes = notes; + this.#savedNotesValue = noteText; + return [ + $el("div.row", { + style: { "align-items": "center" }, + }, [ + $el("h1", ["Notes"]), + saveNotesButton, + downloadNotesButton, + ]), + $el("div", { + style: { "display": "flex", "height": "100%", "min-height": "60px" }, + }, notes), + ]; + })() + ); + + // + // Tags + // /** @type {HTMLDivElement} */ const tagsElement = this.elements.tabContents[2]; // TODO: remove magic value @@ -2762,129 +2866,37 @@ class ModelInfo { const tagButton = this.elements.tabButtons[2]; // TODO: remove magic value tagButton.style.display = isTags ? "" : "none"; - const saveIcon = "content-save"; - const savingIcon = "cloud-upload-outline"; - - const saveNotesButton = new ComfyButton({ - icon: saveIcon, - tooltip: "Save note", - classList: "comfyui-button icon-button", - action: async (e) => { - const [button, icon, span] = comfyButtonDisambiguate(e.target); - button.disabled = true; - const saved = await this.trySave(false); - comfyButtonAlert(e.target, saved); - button.disabled = false; - }, - }).element; - - const downloadNotesButton = new ComfyButton({ - icon: "earth-arrow-down", - tooltip: "Attempt to download model info from the internet.", - classList: "comfyui-button icon-button", - action: async (e) => { - if (this.#savedNotesValue !== "") { - const overwriteNoteConfirmation = window.confirm("Overwrite note?"); - if (!overwriteNoteConfirmation) { - comfyButtonAlert(e.target, false, "mdi-check-bold", "mdi-close-thick"); - return; - } - } - - const [button, icon, span] = comfyButtonDisambiguate(e.target); - button.disabled = true; - const [success, downloadedNotesValue] = await comfyRequest( - `/model-manager/notes/download?path=${path}&overwrite=True`, - { - method: "POST", - body: {}, - } - ).then((data) => { - const success = data["success"]; - const message = data["alert"]; - if (message !== undefined) { - window.alert(message); - } - return [success, data["notes"]]; - }).catch((err) => { - return [false, ""]; - }); - if (success) { - this.#savedNotesValue = downloadedNotesValue; - this.elements.notes.value = downloadedNotesValue; - } - comfyButtonAlert(e.target, success, "mdi-check-bold", "mdi-close-thick"); - button.disabled = false; - }, - }).element; - - const saveDebounce = debounce(async() => { - const saveIconClass = "mdi-" + saveIcon; - const savingIconClass = "mdi-" + savingIcon; - const iconElement = saveNotesButton.getElementsByTagName("i")[0]; - iconElement.classList.remove(saveIconClass); - iconElement.classList.add(savingIconClass); - const saved = await this.trySave(false); - iconElement.classList.remove(savingIconClass); - iconElement.classList.add(saveIconClass); - }, 1000); + // + // Metadata + // /** @type {HTMLDivElement} */ - const notesElement = this.elements.tabContents[3]; // TODO: remove magic value - notesElement.innerHTML = ""; - notesElement.append.apply(notesElement, - (() => { - const notes = $el("textarea.comfy-multiline-input", { - name: "model notes", - value: noteText, - oninput: (e) => { - if (this.#settingsElements["model-info-autosave-notes"].checked) { - saveDebounce(); + const metadataElement = this.elements.tabContents[3]; // TODO: remove magic value + const isMetadata = typeof metadata === 'object' && metadata !== null && Object.keys(metadata).length > 0; + metadataElement.innerHTML = ""; + metadataElement.append.apply(metadataElement, [ + $el("h1", ["Metadata"]), + $el("div", (() => { + const tableRows = []; + if (isMetadata) { + for (const [key, value] of Object.entries(metadata)) { + if (value === undefined || value === null) { + continue; + } + if (value !== "") { + tableRows.push($el("tr", [ + $el("th.model-metadata-key", [key]), + $el("th.model-metadata-value", [value]), + ])); + } } - }, - }); - - if (navigator.userAgent.includes("Mac")) { - new KeyComboListener( - ["MetaLeft", "KeyS"], - saveDebounce, - notes, - ); - new KeyComboListener( - ["MetaRight", "KeyS"], - saveDebounce, - notes, - ); - } - else { - new KeyComboListener( - ["ControlLeft", "KeyS"], - saveDebounce, - notes, - ); - new KeyComboListener( - ["ControlRight", "KeyS"], - saveDebounce, - notes, - ); - } - - this.elements.notes = notes; - this.#savedNotesValue = noteText; - return [ - $el("div.row", { - style: { "align-items": "center" }, - }, [ - $el("h1", ["Notes"]), - saveNotesButton, - downloadNotesButton, - ]), - $el("div", { - style: { "display": "flex", "height": "100%", "min-height": "60px" }, - }, notes), - ]; - })() - ); + } + return $el("table.model-metadata", tableRows); + })(), + ), + ]); + const metadataButton = this.elements.tabButtons[3]; // TODO: remove magic value + metadataButton.style.display = isMetadata ? "" : "none"; } static UniformTagSampling(tagsAndCounts, sampleCount, frequencyThreshold = 0) { @@ -3333,7 +3345,6 @@ async function getModelInfos(urlText) { "fp": file["fp"], "quant": file["size"], "fileFormat": file["format"], - "sha256": file["hashes"]["SHA256"], }, }); }); @@ -3354,7 +3365,6 @@ async function getModelInfos(urlText) { const infos = hfInfo["modelFiles"].map((file) => { const indexSep = file.lastIndexOf("/"); const filename = file.substring(indexSep + 1); - // TODO: get sha256 of each HuggingFace model file return { "images": hfInfo["images"], "fileName": filename, @@ -3657,7 +3667,6 @@ class DownloadView { formData.append("download", info["downloadUrl"]); formData.append("path", pathDirectory); formData.append("name", modelName); - formData.append("sha256", info["details"]["sha256"]); const image = await downloadPreviewSelect.getImage(); formData.append("image", image === PREVIEW_NONE_URI ? "" : image); formData.append("overwrite", this.elements.overwrite.checked);