From f8624698c4494727889cadb97b1a5c0c62fda0a7 Mon Sep 17 00:00:00 2001 From: Christian Bastian Date: Mon, 12 Feb 2024 21:06:53 -0500 Subject: [PATCH] Server download enhancements & debugging. --- __init__.py | 143 +++++++++++++++++++++++++++++++++--------- web/model-manager.css | 1 + web/model-manager.js | 98 ++++++++++++++++++++--------- 3 files changed, 183 insertions(+), 59 deletions(-) diff --git a/__init__.py b/__init__.py index c20bb85..1ab9c17 100644 --- a/__init__.py +++ b/__init__.py @@ -4,10 +4,12 @@ import sys import copy import hashlib import importlib +import re from aiohttp import web import server import urllib.parse +import urllib.request import struct import json import requests @@ -387,68 +389,147 @@ def_headers = { } -def download_model_file(url, filename): +def download_file(url, filename, overwrite): + if not overwrite and os.path.isfile(filename): + raise Exception("File already exists!") + + # TODO: clear any previous failed partial download file dl_filename = filename + ".download" - rh = requests.get( - url=url, stream=True, verify=False, headers=def_headers, proxies=None - ) - print("temp file is " + dl_filename) - total_size = int(rh.headers["Content-Length"]) - - basename, ext = os.path.splitext(filename) - print("Start download {}, file size: {}".format(basename, total_size)) + rh = requests.get(url=url, stream=True, verify=False, headers=def_headers, proxies=None, allow_redirects=False) + if not rh.ok: + raise Exception("Unable to download") downloaded_size = 0 - if os.path.exists(dl_filename): - downloaded_size = os.path.getsize(download_file) + if rh.status_code == 200 and os.path.exists(dl_filename): + downloaded_size = os.path.getsize(dl_filename) headers = {"Range": "bytes=%d-" % downloaded_size} headers["User-Agent"] = def_headers["User-Agent"] - r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None) + r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None, allow_redirects=False) + if rh.status_code == 307 and r.status_code == 307: + # Civitai redirect + redirect_url = r.content.decode("utf-8") + if not redirect_url.startswith("http"): + # Civitai requires login (NSFW or user-required) + # TODO: inform user WHY download failed + raise Exception("Unable to download!") + download_file(redirect_url, filename, overwrite) + return + if rh.status_code == 302 and r.status_code == 302: + # HuggingFace redirect + redirect_url = r.content.decode("utf-8") + redirect_url_index = redirect_url.find("http") + if redirect_url_index == -1: + raise Exception("Unable to download!") + download_file(redirect_url[redirect_url_index:], filename, overwrite) + return + elif rh.status_code == 200 and r.status_code == 206: + # Civitai download link + pass - with open(dl_filename, "ab") as f: + print("temp file is " + dl_filename) + total_size = int(rh.headers.get("Content-Length", 0)) # TODO: pass in total size earlier + + basename, ext = os.path.splitext(filename) + print("Start download " + basename) + if total_size != 0: + print("Download file size: " + str(total_size)) + + mode = "wb" if overwrite else "ab" + with open(dl_filename, mode) as f: for chunk in r.iter_content(chunk_size=1024): if chunk is not None: downloaded_size += len(chunk) f.write(chunk) f.flush() - progress = int(50 * downloaded_size / total_size) - sys.stdout.reconfigure(encoding="utf-8") - sys.stdout.write( - "\r[%s%s] %d%%" - % ( - "-" * progress, - " " * (50 - progress), - 100 * downloaded_size / total_size, + if total_size != 0: + fraction = 1 if downloaded_size == total_size else downloaded_size / total_size + progress = int(50 * fraction) + sys.stdout.reconfigure(encoding="utf-8") + sys.stdout.write( + "\r[%s%s] %d%%" + % ( + "-" * progress, + " " * (50 - progress), + 100 * fraction, + ) ) - ) - sys.stdout.flush() + sys.stdout.flush() print() + if overwrite and os.path.isfile(filename): + os.remove(filename) os.rename(dl_filename, filename) @server.PromptServer.instance.routes.post("/model-manager/download") -async def download_file(request): +async def download_model(request): body = await request.json() + json.dump(body, sys.stdout, indent=4) + + overwrite = body.get("overwrite", False) + model_type = body.get("type") - model_type_path = model_type_to_dir_name(model_type) - if model_type_path is None: + model_path_type = model_type_to_dir_name(model_type) + if model_path_type is None or model_path_type == "": return web.json_response({"success": False}) + model_path = body.get("path", "/0") + model_path = model_path.replace("/", os.path.sep) + regex_result = re.search(r'\d+', model_path) + if regex_result is None: + return web.json_response({"success": False}) + model_path_index = int(regex_result.group()) + paths = folder_paths_get_folder_paths(model_path_type) + if model_path_index < 0 or model_path_index >= len(paths): + return web.json_response({"success": False}) + model_path_span = regex_result.span() + directory = os.path.join( + comfyui_model_uri, + ( + paths[model_path_index] + + model_path[model_path_span[1]:] + ) + ) download_uri = body.get("download") if download_uri is None: return web.json_response({"success": False}) - model_name = body.get("name") - file_name = os.path.join(comfyui_model_uri, model_type_path, model_name) - download_model_file(download_uri, file_name) - print("File download completed!") - return web.json_response({"success": True}) + name = body.get("name") + model_extension = None + for ext in folder_paths_get_supported_pt_extensions(model_type): + if name.endswith(ext): + model_extension = ext + break + if model_extension is None: + return web.json_response({"success": False}) + file_name = os.path.join(directory, name) + try: + download_file(download_uri, file_name, overwrite) + except: + return web.json_response({"success": False}) + image_uri = body.get("image") + if image_uri is not None and image_uri != "": + image_extension = None + for ext in image_extensions: + if image_uri.endswith(ext): + image_extension = ext + break + if image_extension is not None: + image_name = os.path.join( + directory, + (name[:len(name) - len(model_extension)]) + image_extension + ) + try: + download_file(image_uri, image_name, overwrite) + except Exception as e: + print(e, file=sys.stderr, flush=True) + + return web.json_response({"success": True}) WEB_DIRECTORY = "web" NODE_CLASS_MAPPINGS = {} diff --git a/web/model-manager.css b/web/model-manager.css index 180630c..a01a1dd 100644 --- a/web/model-manager.css +++ b/web/model-manager.css @@ -359,6 +359,7 @@ max-height: 30vh; overflow: auto; border-radius: 10px; + z-index: 1; } .search-dropdown:empty { diff --git a/web/model-manager.js b/web/model-manager.js index cafadc3..34b8003 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -55,6 +55,9 @@ const MODEL_SORT_DATE_CREATED = "dateCreated"; const MODEL_SORT_DATE_MODIFIED = "dateModified"; const MODEL_SORT_DATE_NAME = "name"; +const MODEL_EXTENSIONS = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]; // TODO: ask server for? +const IMAGE_EXTENSIONS = [".png", ".webp", ".gif"]; // TODO: ask server for? + /** * Tries to return the related ComfyUI model directory if unambigious. * @@ -194,8 +197,11 @@ async function civitai_getFilteredInfo(stringUrl) { } } else if (urlPath.startsWith('/models')) { - const idStart = urlPath.indexOf("/", 1) + 1; - const idEnd = urlPath.indexOf("/", idStart); + const idStart = urlPath.indexOf("models/") + "models/".length; + const idEnd = (() => { + const idEnd = urlPath.indexOf("/", idStart); + return idEnd === -1 ? urlPath.length : idEnd; + })(); const modelId = urlPath.substring(idStart, idEnd); if (parseInt(modelId, 10) == NaN) { return {}; @@ -209,7 +215,9 @@ async function civitai_getFilteredInfo(stringUrl) { const modelVersionInfos = modelInfo["modelVersions"]; for (let i = 0; i < modelVersionInfos.length; i++) { const versionInfo = modelVersionInfos[i]; - if (modelVersionId instanceof String && modelVersionId != versionInfo["id"]) { continue; } + if (!Number.isNaN(modelVersionId)) { + if (modelVersionId != versionInfo["id"]) {continue; } + } const filesInfo = civitai_getModelFilesInfo(versionInfo); modelVersions.push(filesInfo); } @@ -303,12 +311,11 @@ async function huggingFace_getFilteredInfo(stringUrl) { //const modelInfo = await requestInfo(modelId + branch); // this only gives you the files at the given branch path... // oid: SHA-1?, lfs.oid: SHA-256 - const validModelExtensions = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]; // TODO: ask server for? const clippedFilePath = filePath.substring(filePath[0] === "/" ? 1 : 0); const modelFiles = modelInfo["siblings"].filter((sib) => { const filename = sib["rfilename"]; - for (let i = 0; i < validModelExtensions.length; i++) { - if (filename.endsWith(validModelExtensions[i])) { + for (let i = 0; i < MODEL_EXTENSIONS.length; i++) { + if (filename.endsWith(MODEL_EXTENSIONS[i])) { return filename.startsWith(clippedFilePath); } } @@ -321,11 +328,10 @@ async function huggingFace_getFilteredInfo(stringUrl) { return {}; } - const validImageExtensions = [".png", ".webp", ".gif"]; // TODO: ask server for? const imageFiles = modelInfo["siblings"].filter((sib) => { const filename = sib["rfilename"]; - for (let i = 0; i < validImageExtensions.length; i++) { - if (filename.endsWith(validImageExtensions[i])) { + for (let i = 0; i < IMAGE_EXTENSIONS.length; i++) { + if (filename.endsWith(IMAGE_EXTENSIONS[i])) { return filename.startsWith(filePath); } } @@ -335,7 +341,7 @@ async function huggingFace_getFilteredInfo(stringUrl) { return filename; }); - const baseDownloadUrl = url.origin + urlPath.substring(0, i2) + "/resolve" + branch; + const baseDownloadUrl = url.origin + urlPath.substring(0, i2) + "/resolve" + branch.replace("/tree", ""); return { "baseDownloadUrl": baseDownloadUrl, "modelFiles": modelFiles, @@ -353,6 +359,9 @@ async function huggingFace_getFilteredInfo(stringUrl) { class DirectoryDropdown { /** @type {HTMLDivElement} */ element = undefined; + + /** @type {Boolean} */ + showDirectoriesOnly = false; /** @type {HTMLInputElement} */ #input = undefined; @@ -373,8 +382,9 @@ class DirectoryDropdown { * @param {Function} [updateCallback= () => {}] * @param {Function} [submitCallback= () => {}] * @param {String} [sep="/"] + * @param {Boolean} [showDirectoriesOnly=false] */ - constructor(input, updateDropdown, updateCallback = () => {}, submitCallback = () => {}, sep = "/") { + constructor(input, updateDropdown, updateCallback = () => {}, submitCallback = () => {}, sep = "/", showDirectoriesOnly = false) { /** @type {HTMLDivElement} */ const dropdown = $el("div.search-dropdown", { // TODO: change to `search-directory-dropdown` style: { @@ -386,6 +396,7 @@ class DirectoryDropdown { this.#updateDropdown = updateDropdown; this.#updateCallback = updateCallback; this.#submitCallback = submitCallback; + this.showDirectoriesOnly = showDirectoriesOnly; input.addEventListener("input", () => updateDropdown()); input.addEventListener("focus", () => updateDropdown()); @@ -554,6 +565,7 @@ class DirectoryDropdown { const updateDropdown = this.#updateDropdown; const updateCallback = this.#updateCallback; const submitCallback = this.#submitCallback; + const showDirectoriesOnly = this.showDirectoriesOnly; const filter = input.value; if (filter[0] !== sep) { @@ -631,12 +643,12 @@ class DirectoryDropdown { const grandChildCount = child["childCount"]; const isDir = grandChildCount !== undefined && grandChildCount !== null && grandChildCount > 0; const itemName = child["name"]; - if (itemName.startsWith(lastWord)) { + if (itemName.startsWith(lastWord) && (!showDirectoriesOnly || (showDirectoriesOnly && isDir))) { options.push(itemName + (isDir ? "/" : "")); } } } - else { + else if (!showDirectoriesOnly) { const filename = item["name"]; if (filename.startsWith(lastWord)) { options.push(filename); @@ -1597,6 +1609,7 @@ class ModelManager extends ComfyDialog { this.#modelTab_updatePreviousModelFilter, this.#modelTab_updateModelGrid, this.#sep, + false, ); this.#modelContentFilterDirectoryDropdown = searchDropdown; @@ -1949,10 +1962,6 @@ class ModelManager extends ComfyDialog { filename: null, }; - const datas = { - cachedUrl: "", - }; - $el("input", { $: (el) => (els.saveDirectoryPath = el), type: "text", @@ -1985,6 +1994,7 @@ class ModelManager extends ComfyDialog { () => {}, () => {}, sep, + true, ); const filepath = info["downloadFilePath"]; @@ -1993,12 +2003,24 @@ class ModelManager extends ComfyDialog { $el("div", [ $el("div", [ $el("button", { - onclick: (e) => { - const url = datas.cachedUrl; - const modelType = els.modelTypeSelect.value; // TODO: cannot be empty string or invalid selection - const path = els.saveDirectoryPath.value; // TODO: server: root must be valid - const filename = els.filename.value; // note: does not include file extension - const imgUrl = (() => { + onclick: async (e) => { + const record = {}; + record["download"] = info["downloadUrl"]; + record["type"] = els.modelTypeSelect.value; + if (record["type"] === "") { return; } // TODO: notify user in app + record["path"] = els.saveDirectoryPath.value; + record["name"] = (() => { + const filename = info["fileName"]; + const name = els.filename.value; + if (name === "") { + return filename; + } + const ext = MODEL_EXTENSIONS.find((ext) => { + return filename.endsWith(ext); + }) ?? ""; + return name + ext; + })(); + record["image"] = (() => { const value = document.querySelector(`input[name="${RADIO_MODEL_PREVIEW_GROUP_NAME}"]:checked`).value; switch (value) { case RADIO_MODEL_PREVIEW_DEFAULT: @@ -2015,9 +2037,24 @@ class ModelManager extends ComfyDialog { } return ""; })(); - // TODO: lock downloading - // TODO: send download info to server - // TODO: unlock downloading + record["overwrite"] = true; // TODO: add to UI + e.disabled = true; + await request( + "/model-manager/download", + { + method: "POST", + body: JSON.stringify(record), + } + ).then(data => { + if (data["success"] !== true) { + // TODO: notify user in app + console.error('Failed to download model:', data); + } + }).catch(err => { + // TODO: notify user in app + console.error('Failed to download model:', err); + }); + e.disabled = false; }, }, ["Download"]), els.modelTypeSelect, @@ -2198,7 +2235,7 @@ class ModelManager extends ComfyDialog { "images": [], // TODO: ambiguous? "fileName": filename, "modelType": "", - "downloadUrl": baseDownloadUrl + "/" + file, + "downloadUrl": baseDownloadUrl + "/" + file + "?download=true", "downloadFilePath": file.substring(0, indexSep + 1), "details": { "fileSizeKB": undefined, // TODO: too hard? @@ -2214,7 +2251,12 @@ class ModelManager extends ComfyDialog { })(); const modelTypes = Object.keys(this.#data.models); - const modelInfosHtml = modelInfos.map((modelInfo) => { + const modelInfosHtml = modelInfos.filter((modelInfo) => { + const filename = modelInfo["fileName"]; + return MODEL_EXTENSIONS.find((ext) => { + return filename.endsWith(ext); + }) ?? false; + }).map((modelInfo) => { return this.#downloadTab_modelInfo( modelInfo, modelTypes,