Server download enhancements & debugging.
This commit is contained in:
143
__init__.py
143
__init__.py
@@ -4,10 +4,12 @@ import sys
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from aiohttp import web
|
||||
import server
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import struct
|
||||
import json
|
||||
import requests
|
||||
@@ -387,68 +389,147 @@ def_headers = {
|
||||
}
|
||||
|
||||
|
||||
def download_model_file(url, filename):
|
||||
def download_file(url, filename, overwrite):
|
||||
if not overwrite and os.path.isfile(filename):
|
||||
raise Exception("File already exists!")
|
||||
|
||||
# TODO: clear any previous failed partial download file
|
||||
dl_filename = filename + ".download"
|
||||
|
||||
rh = requests.get(
|
||||
url=url, stream=True, verify=False, headers=def_headers, proxies=None
|
||||
)
|
||||
print("temp file is " + dl_filename)
|
||||
total_size = int(rh.headers["Content-Length"])
|
||||
|
||||
basename, ext = os.path.splitext(filename)
|
||||
print("Start download {}, file size: {}".format(basename, total_size))
|
||||
rh = requests.get(url=url, stream=True, verify=False, headers=def_headers, proxies=None, allow_redirects=False)
|
||||
if not rh.ok:
|
||||
raise Exception("Unable to download")
|
||||
|
||||
downloaded_size = 0
|
||||
if os.path.exists(dl_filename):
|
||||
downloaded_size = os.path.getsize(download_file)
|
||||
if rh.status_code == 200 and os.path.exists(dl_filename):
|
||||
downloaded_size = os.path.getsize(dl_filename)
|
||||
|
||||
headers = {"Range": "bytes=%d-" % downloaded_size}
|
||||
headers["User-Agent"] = def_headers["User-Agent"]
|
||||
|
||||
r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None)
|
||||
r = requests.get(url=url, stream=True, verify=False, headers=headers, proxies=None, allow_redirects=False)
|
||||
if rh.status_code == 307 and r.status_code == 307:
|
||||
# Civitai redirect
|
||||
redirect_url = r.content.decode("utf-8")
|
||||
if not redirect_url.startswith("http"):
|
||||
# Civitai requires login (NSFW or user-required)
|
||||
# TODO: inform user WHY download failed
|
||||
raise Exception("Unable to download!")
|
||||
download_file(redirect_url, filename, overwrite)
|
||||
return
|
||||
if rh.status_code == 302 and r.status_code == 302:
|
||||
# HuggingFace redirect
|
||||
redirect_url = r.content.decode("utf-8")
|
||||
redirect_url_index = redirect_url.find("http")
|
||||
if redirect_url_index == -1:
|
||||
raise Exception("Unable to download!")
|
||||
download_file(redirect_url[redirect_url_index:], filename, overwrite)
|
||||
return
|
||||
elif rh.status_code == 200 and r.status_code == 206:
|
||||
# Civitai download link
|
||||
pass
|
||||
|
||||
with open(dl_filename, "ab") as f:
|
||||
print("temp file is " + dl_filename)
|
||||
total_size = int(rh.headers.get("Content-Length", 0)) # TODO: pass in total size earlier
|
||||
|
||||
basename, ext = os.path.splitext(filename)
|
||||
print("Start download " + basename)
|
||||
if total_size != 0:
|
||||
print("Download file size: " + str(total_size))
|
||||
|
||||
mode = "wb" if overwrite else "ab"
|
||||
with open(dl_filename, mode) as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk is not None:
|
||||
downloaded_size += len(chunk)
|
||||
f.write(chunk)
|
||||
f.flush()
|
||||
|
||||
progress = int(50 * downloaded_size / total_size)
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
sys.stdout.write(
|
||||
"\r[%s%s] %d%%"
|
||||
% (
|
||||
"-" * progress,
|
||||
" " * (50 - progress),
|
||||
100 * downloaded_size / total_size,
|
||||
if total_size != 0:
|
||||
fraction = 1 if downloaded_size == total_size else downloaded_size / total_size
|
||||
progress = int(50 * fraction)
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
sys.stdout.write(
|
||||
"\r[%s%s] %d%%"
|
||||
% (
|
||||
"-" * progress,
|
||||
" " * (50 - progress),
|
||||
100 * fraction,
|
||||
)
|
||||
)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
sys.stdout.flush()
|
||||
|
||||
print()
|
||||
if overwrite and os.path.isfile(filename):
|
||||
os.remove(filename)
|
||||
os.rename(dl_filename, filename)
|
||||
|
||||
|
||||
@server.PromptServer.instance.routes.post("/model-manager/download")
|
||||
async def download_file(request):
|
||||
async def download_model(request):
|
||||
body = await request.json()
|
||||
json.dump(body, sys.stdout, indent=4)
|
||||
|
||||
overwrite = body.get("overwrite", False)
|
||||
|
||||
model_type = body.get("type")
|
||||
model_type_path = model_type_to_dir_name(model_type)
|
||||
if model_type_path is None:
|
||||
model_path_type = model_type_to_dir_name(model_type)
|
||||
if model_path_type is None or model_path_type == "":
|
||||
return web.json_response({"success": False})
|
||||
model_path = body.get("path", "/0")
|
||||
model_path = model_path.replace("/", os.path.sep)
|
||||
regex_result = re.search(r'\d+', model_path)
|
||||
if regex_result is None:
|
||||
return web.json_response({"success": False})
|
||||
model_path_index = int(regex_result.group())
|
||||
paths = folder_paths_get_folder_paths(model_path_type)
|
||||
if model_path_index < 0 or model_path_index >= len(paths):
|
||||
return web.json_response({"success": False})
|
||||
model_path_span = regex_result.span()
|
||||
directory = os.path.join(
|
||||
comfyui_model_uri,
|
||||
(
|
||||
paths[model_path_index] +
|
||||
model_path[model_path_span[1]:]
|
||||
)
|
||||
)
|
||||
|
||||
download_uri = body.get("download")
|
||||
if download_uri is None:
|
||||
return web.json_response({"success": False})
|
||||
|
||||
model_name = body.get("name")
|
||||
file_name = os.path.join(comfyui_model_uri, model_type_path, model_name)
|
||||
download_model_file(download_uri, file_name)
|
||||
print("File download completed!")
|
||||
return web.json_response({"success": True})
|
||||
name = body.get("name")
|
||||
model_extension = None
|
||||
for ext in folder_paths_get_supported_pt_extensions(model_type):
|
||||
if name.endswith(ext):
|
||||
model_extension = ext
|
||||
break
|
||||
if model_extension is None:
|
||||
return web.json_response({"success": False})
|
||||
file_name = os.path.join(directory, name)
|
||||
try:
|
||||
download_file(download_uri, file_name, overwrite)
|
||||
except:
|
||||
return web.json_response({"success": False})
|
||||
|
||||
image_uri = body.get("image")
|
||||
if image_uri is not None and image_uri != "":
|
||||
image_extension = None
|
||||
for ext in image_extensions:
|
||||
if image_uri.endswith(ext):
|
||||
image_extension = ext
|
||||
break
|
||||
if image_extension is not None:
|
||||
image_name = os.path.join(
|
||||
directory,
|
||||
(name[:len(name) - len(model_extension)]) + image_extension
|
||||
)
|
||||
try:
|
||||
download_file(image_uri, image_name, overwrite)
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr, flush=True)
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
WEB_DIRECTORY = "web"
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
|
||||
Reference in New Issue
Block a user