ComfyUI-Model-Manager/py/download.py

import os
import uuid
import time

import requests

import folder_paths

from typing import Callable, Awaitable, Any, Literal, Union, Optional
from dataclasses import dataclass

from . import config
from . import utils
from . import thread


@dataclass
class TaskStatus:
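    """In-memory status of a single download task, mirrored to the client via to_dict()."""
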
    taskId: str
    type: str
    fullname: str
    preview: str
    status: Literal["pause", "waiting", "doing"] = "pause"
    platform: Union[str, None] = None
    downloadedSize: float = 0
    totalSize: float = 0
    progress: float = 0
    bps: float = 0
    error: Optional[str] = None

    def __init__(self, **kwargs):
        self.taskId = kwargs.get("taskId", None)
        self.type = kwargs.get("type", None)
        self.fullname = kwargs.get("fullname", None)
        self.preview = kwargs.get("preview", None)
        self.status = kwargs.get("status", "pause")
        self.platform = kwargs.get("platform", None)
        self.downloadedSize = kwargs.get("downloadedSize", 0)
        self.totalSize = kwargs.get("totalSize", 0)
        self.progress = kwargs.get("progress", 0)
        self.bps = kwargs.get("bps", 0)
        self.error = kwargs.get("error", None)

    def to_dict(self):
        return {
            "taskId": self.taskId,
            "type": self.type,
            "fullname": self.fullname,
            "preview": self.preview,
            "status": self.status,
            "platform": self.platform,
            "downloadedSize": self.downloadedSize,
            "totalSize": self.totalSize,
            "progress": self.progress,
            "bps": self.bps,
            "error": self.error,
        }


@dataclass
class TaskContent:
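    """Persistent task metadata, pickled to the <task_id>.task file in the download directory."""
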
    type: str
    pathIndex: int
    fullname: str
    description: str
    downloadPlatform: str
    downloadUrl: str
    sizeBytes: int
    hashes: Optional[dict[str, str]] = None

    def __init__(self, **kwargs):
        self.type = kwargs.get("type", None)
        self.pathIndex = int(kwargs.get("pathIndex", 0))
        self.fullname = kwargs.get("fullname", None)
        self.description = kwargs.get("description", None)
        self.downloadPlatform = kwargs.get("downloadPlatform", None)
        self.downloadUrl = kwargs.get("downloadUrl", None)
        self.sizeBytes = int(kwargs.get("sizeBytes", 0))
        self.hashes = kwargs.get("hashes", None)

    def to_dict(self):
        return {
            "type": self.type,
            "pathIndex": self.pathIndex,
            "fullname": self.fullname,
            "description": self.description,
            "downloadPlatform": self.downloadPlatform,
            "downloadUrl": self.downloadUrl,
            "sizeBytes": self.sizeBytes,
            "hashes": self.hashes,
        }
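

# In-memory status of every known download task, keyed by task id, and the
# thread pool that runs the download tasks.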
download_model_task_status: dict[str, TaskStatus] = {}
download_thread_pool = thread.DownloadThreadPool()


def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
    download_path = utils.get_download_path()
    task_file_path = utils.join_path(download_path, f"{task_id}.task")
    utils.save_dict_pickle_file(task_file_path, task_content)


def get_task_content(task_id: str):
    download_path = utils.get_download_path()
    task_file = utils.join_path(download_path, f"{task_id}.task")
    if not os.path.isfile(task_file):
        raise RuntimeError(f"Task {task_id} not found")
    task_content = utils.load_dict_pickle_file(task_file)
    return TaskContent(**task_content)


def get_task_status(task_id: str):
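    """Return the cached TaskStatus for a task, rebuilding it from the on-disk
    .task and .download files when it is not in memory yet."""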
    task_status = download_model_task_status.get(task_id, None)
    if task_status is None:
        download_path = utils.get_download_path()
        task_content = get_task_content(task_id)
        download_file = utils.join_path(download_path, f"{task_id}.download")
        download_size = 0
        if os.path.exists(download_file):
            download_size = os.path.getsize(download_file)
        total_size = task_content.sizeBytes
        task_status = TaskStatus(
            taskId=task_id,
            type=task_content.type,
            fullname=task_content.fullname,
            preview=utils.get_model_preview_name(download_file),
            platform=task_content.downloadPlatform,
            downloadedSize=download_size,
            totalSize=task_content.sizeBytes,
            progress=download_size / total_size * 100 if total_size > 0 else 0,
        )
        download_model_task_status[task_id] = task_status
    return task_status


def delete_task_status(task_id: str):
    download_model_task_status.pop(task_id, None)


async def scan_model_download_task_list():
    """
    Scan the download directory and return the download task list for the client.
    """
    download_dir = utils.get_download_path()
    task_files = utils.search_files(download_dir)
    task_files = folder_paths.filter_files_extensions(task_files, [".task"])
    task_files = sorted(
        task_files,
        key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
        reverse=True,
    )
    task_list: list[dict] = []
    for task_file in task_files:
        task_id = task_file.replace(".task", "")
        task_status = get_task_status(task_id)
        task_list.append(task_status.to_dict())
    return task_list


async def create_model_download_task(task_data: dict, request):
    """
    Create a download task for the given data.
    """
    model_type = task_data.get("type", None)
    path_index = int(task_data.get("pathIndex", None))
    fullname = task_data.get("fullname", None)
    model_path = utils.get_full_path(model_type, path_index, fullname)

    # Refuse to overwrite an existing model file
    if os.path.exists(model_path):
        raise RuntimeError(f"File already exists: {model_path}")

    download_path = utils.get_download_path()
    task_id = uuid.uuid4().hex
    task_path = utils.join_path(download_path, f"{task_id}.task")
    if os.path.exists(task_path):
        raise RuntimeError(f"Task {task_id} already exists")

    try:
        preview_url = task_data.pop("preview", None)
        utils.save_model_preview_image(task_path, preview_url)
        set_task_content(task_id, task_data)

        task_status = TaskStatus(
            taskId=task_id,
            type=model_type,
            fullname=fullname,
            preview=utils.get_model_preview_name(task_path),
            platform=task_data.get("downloadPlatform", None),
            totalSize=float(task_data.get("sizeBytes", 0)),
        )
        download_model_task_status[task_id] = task_status
        await utils.send_json("create_download_task", task_status.to_dict())
    except Exception as e:
        await delete_model_download_task(task_id)
        raise RuntimeError(str(e)) from e

    await download_model(task_id, request)
    return task_id


async def pause_model_download_task(task_id: str):
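    """Mark a task as paused; the download loop checks this flag between chunks."""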
    task_status = get_task_status(task_id=task_id)
    task_status.status = "pause"


async def delete_model_download_task(task_id: str):
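    """Delete a task: pause it if it is running, drop its cached status and
    remove every file in the download directory that belongs to the task id."""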
    task_status = get_task_status(task_id)
    is_running = task_status.status == "doing"
    task_status.status = "waiting"
    await utils.send_json("delete_download_task", task_id)

    # Pause the task
    if is_running:
        task_status.status = "pause"
        time.sleep(1)

    download_dir = utils.get_download_path()
    task_file_list = os.listdir(download_dir)
    for task_file in task_file_list:
        task_file_target = os.path.splitext(task_file)[0]
        if task_file_target == task_id:
            delete_task_status(task_id)
            os.remove(utils.join_path(download_dir, task_file))

    await utils.send_json("delete_download_task", task_id)


async def download_model(task_id: str, request):
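    """Submit the download task to the thread pool and mark it as waiting if it is queued."""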
    async def download_task(task_id: str):
        async def report_progress(task_status: TaskStatus):
            await utils.send_json("update_download_task", task_status.to_dict())

        try:
            # When starting a task from the queue, the task may not exist
            task_status = get_task_status(task_id)
        except:
            return

        # Update task status
        task_status.status = "doing"
        await utils.send_json("update_download_task", task_status.to_dict())

        try:
            # Set download request headers
            headers = {"User-Agent": config.user_agent}

            download_platform = task_status.platform
            if download_platform == "civitai":
                api_key = utils.get_setting_value(request, "api_key.civitai")
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"
            elif download_platform == "huggingface":
                api_key = utils.get_setting_value(request, "api_key.huggingface")
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            progress_interval = 1.0
            await download_model_file(
                task_id=task_id,
                headers=headers,
                progress_callback=report_progress,
                interval=progress_interval,
            )
        except Exception as e:
            task_status.status = "pause"
            task_status.error = str(e)
            await utils.send_json("update_download_task", task_status.to_dict())
            task_status.error = None
            utils.print_error(str(e))

    try:
        status = download_thread_pool.submit(download_task, task_id)
        if status == "Waiting":
            task_status = get_task_status(task_id)
            task_status.status = "waiting"
            await utils.send_json("update_download_task", task_status.to_dict())
    except Exception as e:
        # Look the status up again so task_status is always bound, even when
        # submit() raises before the "Waiting" branch assigns it.
        task_status = get_task_status(task_id)
        task_status.status = "pause"
        task_status.error = str(e)
        await utils.send_json("update_download_task", task_status.to_dict())
        task_status.error = None
        utils.print_error(str(e))


async def download_model_file(
    task_id: str,
    headers: dict,
    progress_callback: Callable[[TaskStatus], Awaitable[Any]],
    interval: float = 1.0,
):
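    """Stream the model to <task_id>.download with HTTP Range resume support,
    then move the finished file into place and write its description file."""
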
    async def download_complete():
        """
        Restore the model information from the task file
        and move the model file to the target directory.
        """
        model_type = task_content.type
        path_index = task_content.pathIndex
        fullname = task_content.fullname

        # Write description file
        description = task_content.description
        description_file = utils.join_path(download_path, f"{task_id}.md")
        with open(description_file, "w", encoding="utf-8", newline="") as f:
            f.write(description)

        model_path = utils.get_full_path(model_type, path_index, fullname)
        utils.rename_model(download_tmp_file, model_path)
        time.sleep(1)

        task_file = utils.join_path(download_path, f"{task_id}.task")
        os.remove(task_file)

        await utils.send_json("complete_download_task", task_id)

    async def update_progress():
        nonlocal last_update_time
        nonlocal last_downloaded_size
        progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
        task_status.downloadedSize = downloaded_size
        task_status.progress = progress
        task_status.bps = downloaded_size - last_downloaded_size
        await progress_callback(task_status)
        last_update_time = time.time()
        last_downloaded_size = downloaded_size

    task_status = get_task_status(task_id)
    task_content = get_task_content(task_id)

    # Check the download URL
    model_url = task_content.downloadUrl
    if not model_url:
        raise RuntimeError("No downloadUrl found")

    download_path = utils.get_download_path()
    download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
    downloaded_size = 0
    if os.path.isfile(download_tmp_file):
        # Resume a partial download from where the temporary file left off
        downloaded_size = os.path.getsize(download_tmp_file)
        headers["Range"] = f"bytes={downloaded_size}-"

    total_size = task_content.sizeBytes
    if total_size > 0 and downloaded_size == total_size:
        await download_complete()
        return

    last_update_time = time.time()
    last_downloaded_size = downloaded_size

    response = requests.get(
        url=model_url,
        headers=headers,
        stream=True,
        allow_redirects=True,
    )
    if response.status_code not in (200, 206):
        raise RuntimeError(
            f"Failed to download {task_content.fullname}, status code: {response.status_code}"
        )

    # Some models require logging in before they can be downloaded.
    # If no token is carried, the request is redirected to the login page.
    content_type = response.headers.get("content-type")
    if content_type and content_type.startswith("text/html"):
        # TODO: More checks.
        # Besides requiring login, there may be other restrictions;
        # the only case seen so far is early access (issues#43).
        # Due to the lack of test data, this is left aside for now.
        # If the file cannot be downloaded, a redirect will definitely occur,
        # so one option is to inspect the redirect URL in response.history,
        # taking into account how each website handles this.
        raise RuntimeError(
            f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
        )

    # When the model information parsed from the HuggingFace API has no file size,
    # obtain it from the response headers.
    if total_size == 0:
        total_size = int(response.headers.get("content-length", 0))
        task_content.sizeBytes = total_size
        task_status.totalSize = total_size
        set_task_content(task_id, task_content)
        await utils.send_json("update_download_task", task_content.to_dict())

    with open(download_tmp_file, "ab") as f:
        for chunk in response.iter_content(chunk_size=8192):
            if task_status.status == "pause":
                # Stop writing; the partial .download file is kept so the
                # task can be resumed later.
                break
            f.write(chunk)
            downloaded_size += len(chunk)
            if time.time() - last_update_time >= interval:
                await update_progress()

    await update_progress()

    if total_size > 0 and downloaded_size == total_size:
        await download_complete()
    else:
        task_status.status = "pause"
        await utils.send_json("update_download_task", task_status.to_dict())