support download via aria2 (#797)

This commit is contained in:
dishuostec
2024-06-22 21:11:57 +08:00
committed by GitHub
parent 951e5ecd67
commit 7b9292fbd4
4 changed files with 110 additions and 1 deletions

View File

@@ -0,0 +1,67 @@
import os
aria2 = os.getenv('COMFYUI_MANAGER_ARIA2_SERVER')
HF_ENDPOINT = os.getenv('HF_ENDPOINT')
if aria2 is not None:
secret = os.getenv('COMFYUI_MANAGER_ARIA2_SECRET')
host, port = aria2.split(':')
import aria2p
aria2 = aria2p.API(aria2p.Client(host=host, port=port, secret=secret))
def download_url(model_url: str, model_dir: str, filename: str):
if aria2:
return aria2_download_url(model_url, model_dir, filename)
else:
from torchvision.datasets.utils import download_url as torchvision_download_url
return torchvision_download_url(model_url, model_dir, filename)
def aria2_find_task(dir: str, filename: str):
target = os.path.join(dir, filename)
downloads = aria2.get_downloads()
for download in downloads:
for file in download.files:
if file.is_metadata:
continue
if str(file.path) == target:
return download
def aria2_download_url(model_url: str, model_dir: str, filename: str):
import manager_core as core
import tqdm
import time
if model_dir.startswith(core.comfy_path):
model_dir = model_dir[len(core.comfy_path) :]
if HF_ENDPOINT:
model_url = model_url.replace('https://huggingface.co', HF_ENDPOINT)
download_dir = model_dir if model_dir.startswith('/') else os.path.join('/models', model_dir)
download = aria2_find_task(download_dir, filename)
if download is None:
options = {'dir': download_dir, 'out': filename}
download = aria2.add(model_url, options)[0]
if download.is_active:
with tqdm.tqdm(
total=download.total_length,
bar_format='{l_bar}{bar}{r_bar}',
desc=filename,
unit='B',
unit_scale=True,
) as progress_bar:
while download.is_active:
if progress_bar.total == 0 and download.total_length != 0:
progress_bar.reset(download.total_length)
progress_bar.update(download.completed_length - progress_bar.n)
time.sleep(1)
download.update()

View File

@@ -106,7 +106,7 @@ core.manager_funcs = ManagerFuncsInComfyUI()
sys.path.append('../..')
from torchvision.datasets.utils import download_url
from manager_downloader import download_url
core.comfy_path = os.path.dirname(folder_paths.__file__)
core.js_path = os.path.join(core.comfy_path, "web", "extensions")