Server download enhancements & debugging.

This commit is contained in:
Christian Bastian
2024-02-12 21:06:53 -05:00
parent 1e844982c3
commit f8624698c4
3 changed files with 183 additions and 59 deletions

View File

@@ -4,10 +4,12 @@ import sys
import copy import copy
import hashlib import hashlib
import importlib import importlib
import re
from aiohttp import web from aiohttp import web
import server import server
import urllib.parse import urllib.parse
import urllib.request
import struct import struct
import json import json
import requests import requests
@@ -387,68 +389,147 @@ def_headers = {
} }
def download_model_file(url, filename): def download_file(url, filename, overwrite):
if not overwrite and os.path.isfile(filename):
raise Exception("File already exists!")
# TODO: clear any previous failed partial download file
dl_filename = filename + ".download" dl_filename = filename + ".download"
rh = requests.get( rh = requests.get(url=url, stream=True, verify=False, headers=def_headers, proxies=None, allow_redirects=False)
url=url, stream=True, verify=False, headers=def_headers, proxies=None if not rh.ok:
) raise Exception("Unable to download")
print("temp file is " + dl_filename)
total_size = int(rh.headers["Content-Length"])
basename, ext = os.path.splitext(filename)
print("Start download {}, file size: {}".format(basename, total_size))
downloaded_size = 0 downloaded_size = 0
if os.path.exists(dl_filename): if rh.status_code == 200 and os.path.exists(dl_filename):
downloaded_size = os.path.getsize(download_file) downloaded_size = os.path.getsize(dl_filename)
headers = {"Range": "bytes=%d-" % downloaded_size} headers = {"Range": "bytes=%d-" % downloaded_size}
headers["User-Agent"] = def_headers["User-Agent"] headers["User-Agent"] = def_headers["User-Agent"]
r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None) r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None, allow_redirects=False)
if rh.status_code == 307 and r.status_code == 307:
# Civitai redirect
redirect_url = r.content.decode("utf-8")
if not redirect_url.startswith("http"):
# Civitai requires login (NSFW or user-required)
# TODO: inform user WHY download failed
raise Exception("Unable to download!")
download_file(redirect_url, filename, overwrite)
return
if rh.status_code == 302 and r.status_code == 302:
# HuggingFace redirect
redirect_url = r.content.decode("utf-8")
redirect_url_index = redirect_url.find("http")
if redirect_url_index == -1:
raise Exception("Unable to download!")
download_file(redirect_url[redirect_url_index:], filename, overwrite)
return
elif rh.status_code == 200 and r.status_code == 206:
# Civitai download link
pass
with open(dl_filename, "ab") as f: print("temp file is " + dl_filename)
total_size = int(rh.headers.get("Content-Length", 0)) # TODO: pass in total size earlier
basename, ext = os.path.splitext(filename)
print("Start download " + basename)
if total_size != 0:
print("Download file size: " + str(total_size))
mode = "wb" if overwrite else "ab"
with open(dl_filename, mode) as f:
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk is not None: if chunk is not None:
downloaded_size += len(chunk) downloaded_size += len(chunk)
f.write(chunk) f.write(chunk)
f.flush() f.flush()
progress = int(50 * downloaded_size / total_size) if total_size != 0:
sys.stdout.reconfigure(encoding="utf-8") fraction = 1 if downloaded_size == total_size else downloaded_size / total_size
sys.stdout.write( progress = int(50 * fraction)
"\r[%s%s] %d%%" sys.stdout.reconfigure(encoding="utf-8")
% ( sys.stdout.write(
"-" * progress, "\r[%s%s] %d%%"
" " * (50 - progress), % (
100 * downloaded_size / total_size, "-" * progress,
" " * (50 - progress),
100 * fraction,
)
) )
) sys.stdout.flush()
sys.stdout.flush()
print() print()
if overwrite and os.path.isfile(filename):
os.remove(filename)
os.rename(dl_filename, filename) os.rename(dl_filename, filename)
@server.PromptServer.instance.routes.post("/model-manager/download") @server.PromptServer.instance.routes.post("/model-manager/download")
async def download_file(request): async def download_model(request):
body = await request.json() body = await request.json()
json.dump(body, sys.stdout, indent=4)
overwrite = body.get("overwrite", False)
model_type = body.get("type") model_type = body.get("type")
model_type_path = model_type_to_dir_name(model_type) model_path_type = model_type_to_dir_name(model_type)
if model_type_path is None: if model_path_type is None or model_path_type == "":
return web.json_response({"success": False}) return web.json_response({"success": False})
model_path = body.get("path", "/0")
model_path = model_path.replace("/", os.path.sep)
regex_result = re.search(r'\d+', model_path)
if regex_result is None:
return web.json_response({"success": False})
model_path_index = int(regex_result.group())
paths = folder_paths_get_folder_paths(model_path_type)
if model_path_index < 0 or model_path_index >= len(paths):
return web.json_response({"success": False})
model_path_span = regex_result.span()
directory = os.path.join(
comfyui_model_uri,
(
paths[model_path_index] +
model_path[model_path_span[1]:]
)
)
download_uri = body.get("download") download_uri = body.get("download")
if download_uri is None: if download_uri is None:
return web.json_response({"success": False}) return web.json_response({"success": False})
model_name = body.get("name") name = body.get("name")
file_name = os.path.join(comfyui_model_uri, model_type_path, model_name) model_extension = None
download_model_file(download_uri, file_name) for ext in folder_paths_get_supported_pt_extensions(model_type):
print("File download completed!") if name.endswith(ext):
return web.json_response({"success": True}) model_extension = ext
break
if model_extension is None:
return web.json_response({"success": False})
file_name = os.path.join(directory, name)
try:
download_file(download_uri, file_name, overwrite)
except:
return web.json_response({"success": False})
image_uri = body.get("image")
if image_uri is not None and image_uri != "":
image_extension = None
for ext in image_extensions:
if image_uri.endswith(ext):
image_extension = ext
break
if image_extension is not None:
image_name = os.path.join(
directory,
(name[:len(name) - len(model_extension)]) + image_extension
)
try:
download_file(image_uri, image_name, overwrite)
except Exception as e:
print(e, file=sys.stderr, flush=True)
return web.json_response({"success": True})
WEB_DIRECTORY = "web" WEB_DIRECTORY = "web"
NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS = {}

View File

@@ -359,6 +359,7 @@
max-height: 30vh; max-height: 30vh;
overflow: auto; overflow: auto;
border-radius: 10px; border-radius: 10px;
z-index: 1;
} }
.search-dropdown:empty { .search-dropdown:empty {

View File

@@ -55,6 +55,9 @@ const MODEL_SORT_DATE_CREATED = "dateCreated";
const MODEL_SORT_DATE_MODIFIED = "dateModified"; const MODEL_SORT_DATE_MODIFIED = "dateModified";
const MODEL_SORT_DATE_NAME = "name"; const MODEL_SORT_DATE_NAME = "name";
const MODEL_EXTENSIONS = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]; // TODO: ask server for?
const IMAGE_EXTENSIONS = [".png", ".webp", ".gif"]; // TODO: ask server for?
/** /**
* Tries to return the related ComfyUI model directory if unambigious. * Tries to return the related ComfyUI model directory if unambigious.
* *
@@ -194,8 +197,11 @@ async function civitai_getFilteredInfo(stringUrl) {
} }
} }
else if (urlPath.startsWith('/models')) { else if (urlPath.startsWith('/models')) {
const idStart = urlPath.indexOf("/", 1) + 1; const idStart = urlPath.indexOf("models/") + "models/".length;
const idEnd = urlPath.indexOf("/", idStart); const idEnd = (() => {
const idEnd = urlPath.indexOf("/", idStart);
return idEnd === -1 ? urlPath.length : idEnd;
})();
const modelId = urlPath.substring(idStart, idEnd); const modelId = urlPath.substring(idStart, idEnd);
if (parseInt(modelId, 10) == NaN) { if (parseInt(modelId, 10) == NaN) {
return {}; return {};
@@ -209,7 +215,9 @@ async function civitai_getFilteredInfo(stringUrl) {
const modelVersionInfos = modelInfo["modelVersions"]; const modelVersionInfos = modelInfo["modelVersions"];
for (let i = 0; i < modelVersionInfos.length; i++) { for (let i = 0; i < modelVersionInfos.length; i++) {
const versionInfo = modelVersionInfos[i]; const versionInfo = modelVersionInfos[i];
if (modelVersionId instanceof String && modelVersionId != versionInfo["id"]) { continue; } if (!Number.isNaN(modelVersionId)) {
if (modelVersionId != versionInfo["id"]) {continue; }
}
const filesInfo = civitai_getModelFilesInfo(versionInfo); const filesInfo = civitai_getModelFilesInfo(versionInfo);
modelVersions.push(filesInfo); modelVersions.push(filesInfo);
} }
@@ -303,12 +311,11 @@ async function huggingFace_getFilteredInfo(stringUrl) {
//const modelInfo = await requestInfo(modelId + branch); // this only gives you the files at the given branch path... //const modelInfo = await requestInfo(modelId + branch); // this only gives you the files at the given branch path...
// oid: SHA-1?, lfs.oid: SHA-256 // oid: SHA-1?, lfs.oid: SHA-256
const validModelExtensions = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]; // TODO: ask server for?
const clippedFilePath = filePath.substring(filePath[0] === "/" ? 1 : 0); const clippedFilePath = filePath.substring(filePath[0] === "/" ? 1 : 0);
const modelFiles = modelInfo["siblings"].filter((sib) => { const modelFiles = modelInfo["siblings"].filter((sib) => {
const filename = sib["rfilename"]; const filename = sib["rfilename"];
for (let i = 0; i < validModelExtensions.length; i++) { for (let i = 0; i < MODEL_EXTENSIONS.length; i++) {
if (filename.endsWith(validModelExtensions[i])) { if (filename.endsWith(MODEL_EXTENSIONS[i])) {
return filename.startsWith(clippedFilePath); return filename.startsWith(clippedFilePath);
} }
} }
@@ -321,11 +328,10 @@ async function huggingFace_getFilteredInfo(stringUrl) {
return {}; return {};
} }
const validImageExtensions = [".png", ".webp", ".gif"]; // TODO: ask server for?
const imageFiles = modelInfo["siblings"].filter((sib) => { const imageFiles = modelInfo["siblings"].filter((sib) => {
const filename = sib["rfilename"]; const filename = sib["rfilename"];
for (let i = 0; i < validImageExtensions.length; i++) { for (let i = 0; i < IMAGE_EXTENSIONS.length; i++) {
if (filename.endsWith(validImageExtensions[i])) { if (filename.endsWith(IMAGE_EXTENSIONS[i])) {
return filename.startsWith(filePath); return filename.startsWith(filePath);
} }
} }
@@ -335,7 +341,7 @@ async function huggingFace_getFilteredInfo(stringUrl) {
return filename; return filename;
}); });
const baseDownloadUrl = url.origin + urlPath.substring(0, i2) + "/resolve" + branch; const baseDownloadUrl = url.origin + urlPath.substring(0, i2) + "/resolve" + branch.replace("/tree", "");
return { return {
"baseDownloadUrl": baseDownloadUrl, "baseDownloadUrl": baseDownloadUrl,
"modelFiles": modelFiles, "modelFiles": modelFiles,
@@ -353,6 +359,9 @@ async function huggingFace_getFilteredInfo(stringUrl) {
class DirectoryDropdown { class DirectoryDropdown {
/** @type {HTMLDivElement} */ /** @type {HTMLDivElement} */
element = undefined; element = undefined;
/** @type {Boolean} */
showDirectoriesOnly = false;
/** @type {HTMLInputElement} */ /** @type {HTMLInputElement} */
#input = undefined; #input = undefined;
@@ -373,8 +382,9 @@ class DirectoryDropdown {
* @param {Function} [updateCallback= () => {}] * @param {Function} [updateCallback= () => {}]
* @param {Function} [submitCallback= () => {}] * @param {Function} [submitCallback= () => {}]
* @param {String} [sep="/"] * @param {String} [sep="/"]
* @param {Boolean} [showDirectoriesOnly=false]
*/ */
constructor(input, updateDropdown, updateCallback = () => {}, submitCallback = () => {}, sep = "/") { constructor(input, updateDropdown, updateCallback = () => {}, submitCallback = () => {}, sep = "/", showDirectoriesOnly = false) {
/** @type {HTMLDivElement} */ /** @type {HTMLDivElement} */
const dropdown = $el("div.search-dropdown", { // TODO: change to `search-directory-dropdown` const dropdown = $el("div.search-dropdown", { // TODO: change to `search-directory-dropdown`
style: { style: {
@@ -386,6 +396,7 @@ class DirectoryDropdown {
this.#updateDropdown = updateDropdown; this.#updateDropdown = updateDropdown;
this.#updateCallback = updateCallback; this.#updateCallback = updateCallback;
this.#submitCallback = submitCallback; this.#submitCallback = submitCallback;
this.showDirectoriesOnly = showDirectoriesOnly;
input.addEventListener("input", () => updateDropdown()); input.addEventListener("input", () => updateDropdown());
input.addEventListener("focus", () => updateDropdown()); input.addEventListener("focus", () => updateDropdown());
@@ -554,6 +565,7 @@ class DirectoryDropdown {
const updateDropdown = this.#updateDropdown; const updateDropdown = this.#updateDropdown;
const updateCallback = this.#updateCallback; const updateCallback = this.#updateCallback;
const submitCallback = this.#submitCallback; const submitCallback = this.#submitCallback;
const showDirectoriesOnly = this.showDirectoriesOnly;
const filter = input.value; const filter = input.value;
if (filter[0] !== sep) { if (filter[0] !== sep) {
@@ -631,12 +643,12 @@ class DirectoryDropdown {
const grandChildCount = child["childCount"]; const grandChildCount = child["childCount"];
const isDir = grandChildCount !== undefined && grandChildCount !== null && grandChildCount > 0; const isDir = grandChildCount !== undefined && grandChildCount !== null && grandChildCount > 0;
const itemName = child["name"]; const itemName = child["name"];
if (itemName.startsWith(lastWord)) { if (itemName.startsWith(lastWord) && (!showDirectoriesOnly || (showDirectoriesOnly && isDir))) {
options.push(itemName + (isDir ? "/" : "")); options.push(itemName + (isDir ? "/" : ""));
} }
} }
} }
else { else if (!showDirectoriesOnly) {
const filename = item["name"]; const filename = item["name"];
if (filename.startsWith(lastWord)) { if (filename.startsWith(lastWord)) {
options.push(filename); options.push(filename);
@@ -1597,6 +1609,7 @@ class ModelManager extends ComfyDialog {
this.#modelTab_updatePreviousModelFilter, this.#modelTab_updatePreviousModelFilter,
this.#modelTab_updateModelGrid, this.#modelTab_updateModelGrid,
this.#sep, this.#sep,
false,
); );
this.#modelContentFilterDirectoryDropdown = searchDropdown; this.#modelContentFilterDirectoryDropdown = searchDropdown;
@@ -1949,10 +1962,6 @@ class ModelManager extends ComfyDialog {
filename: null, filename: null,
}; };
const datas = {
cachedUrl: "",
};
$el("input", { $el("input", {
$: (el) => (els.saveDirectoryPath = el), $: (el) => (els.saveDirectoryPath = el),
type: "text", type: "text",
@@ -1985,6 +1994,7 @@ class ModelManager extends ComfyDialog {
() => {}, () => {},
() => {}, () => {},
sep, sep,
true,
); );
const filepath = info["downloadFilePath"]; const filepath = info["downloadFilePath"];
@@ -1993,12 +2003,24 @@ class ModelManager extends ComfyDialog {
$el("div", [ $el("div", [
$el("div", [ $el("div", [
$el("button", { $el("button", {
onclick: (e) => { onclick: async (e) => {
const url = datas.cachedUrl; const record = {};
const modelType = els.modelTypeSelect.value; // TODO: cannot be empty string or invalid selection record["download"] = info["downloadUrl"];
const path = els.saveDirectoryPath.value; // TODO: server: root must be valid record["type"] = els.modelTypeSelect.value;
const filename = els.filename.value; // note: does not include file extension if (record["type"] === "") { return; } // TODO: notify user in app
const imgUrl = (() => { record["path"] = els.saveDirectoryPath.value;
record["name"] = (() => {
const filename = info["fileName"];
const name = els.filename.value;
if (name === "") {
return filename;
}
const ext = MODEL_EXTENSIONS.find((ext) => {
return filename.endsWith(ext);
}) ?? "";
return name + ext;
})();
record["image"] = (() => {
const value = document.querySelector(`input[name="${RADIO_MODEL_PREVIEW_GROUP_NAME}"]:checked`).value; const value = document.querySelector(`input[name="${RADIO_MODEL_PREVIEW_GROUP_NAME}"]:checked`).value;
switch (value) { switch (value) {
case RADIO_MODEL_PREVIEW_DEFAULT: case RADIO_MODEL_PREVIEW_DEFAULT:
@@ -2015,9 +2037,24 @@ class ModelManager extends ComfyDialog {
} }
return ""; return "";
})(); })();
// TODO: lock downloading record["overwrite"] = true; // TODO: add to UI
// TODO: send download info to server e.disabled = true;
// TODO: unlock downloading await request(
"/model-manager/download",
{
method: "POST",
body: JSON.stringify(record),
}
).then(data => {
if (data["success"] !== true) {
// TODO: notify user in app
console.error('Failed to download model:', data);
}
}).catch(err => {
// TODO: notify user in app
console.error('Failed to download model:', err);
});
e.disabled = false;
}, },
}, ["Download"]), }, ["Download"]),
els.modelTypeSelect, els.modelTypeSelect,
@@ -2198,7 +2235,7 @@ class ModelManager extends ComfyDialog {
"images": [], // TODO: ambiguous? "images": [], // TODO: ambiguous?
"fileName": filename, "fileName": filename,
"modelType": "", "modelType": "",
"downloadUrl": baseDownloadUrl + "/" + file, "downloadUrl": baseDownloadUrl + "/" + file + "?download=true",
"downloadFilePath": file.substring(0, indexSep + 1), "downloadFilePath": file.substring(0, indexSep + 1),
"details": { "details": {
"fileSizeKB": undefined, // TODO: too hard? "fileSizeKB": undefined, // TODO: too hard?
@@ -2214,7 +2251,12 @@ class ModelManager extends ComfyDialog {
})(); })();
const modelTypes = Object.keys(this.#data.models); const modelTypes = Object.keys(this.#data.models);
const modelInfosHtml = modelInfos.map((modelInfo) => { const modelInfosHtml = modelInfos.filter((modelInfo) => {
const filename = modelInfo["fileName"];
return MODEL_EXTENSIONS.find((ext) => {
return filename.endsWith(ext);
}) ?? false;
}).map((modelInfo) => {
return this.#downloadTab_modelInfo( return this.#downloadTab_modelInfo(
modelInfo, modelInfo,
modelTypes, modelTypes,