fix: download module error (#141)
This commit is contained in:
651
py/download.py
651
py/download.py
@@ -3,8 +3,13 @@ import uuid
|
||||
import time
|
||||
import requests
|
||||
import folder_paths
|
||||
|
||||
|
||||
from typing import Callable, Awaitable, Any, Literal, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
from . import config
|
||||
from . import utils
|
||||
from . import thread
|
||||
@@ -87,330 +92,6 @@ class TaskContent:
|
||||
}
|
||||
|
||||
|
||||
download_model_task_status: dict[str, TaskStatus] = {}
|
||||
download_thread_pool = thread.DownloadThreadPool()
|
||||
|
||||
|
||||
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
|
||||
download_path = utils.get_download_path()
|
||||
task_file_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
utils.save_dict_pickle_file(task_file_path, task_content)
|
||||
|
||||
|
||||
def get_task_content(task_id: str):
|
||||
download_path = utils.get_download_path()
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
if not os.path.isfile(task_file):
|
||||
raise RuntimeError(f"Task {task_id} not found")
|
||||
task_content = utils.load_dict_pickle_file(task_file)
|
||||
if isinstance(task_content, TaskContent):
|
||||
return task_content
|
||||
return TaskContent(**task_content)
|
||||
|
||||
|
||||
def get_task_status(task_id: str):
|
||||
task_status = download_model_task_status.get(task_id, None)
|
||||
|
||||
if task_status is None:
|
||||
download_path = utils.get_download_path()
|
||||
task_content = get_task_content(task_id)
|
||||
download_file = utils.join_path(download_path, f"{task_id}.download")
|
||||
download_size = 0
|
||||
if os.path.exists(download_file):
|
||||
download_size = os.path.getsize(download_file)
|
||||
|
||||
total_size = task_content.sizeBytes
|
||||
task_status = TaskStatus(
|
||||
taskId=task_id,
|
||||
type=task_content.type,
|
||||
fullname=task_content.fullname,
|
||||
preview=utils.get_model_preview_name(download_file),
|
||||
platform=task_content.downloadPlatform,
|
||||
downloadedSize=download_size,
|
||||
totalSize=task_content.sizeBytes,
|
||||
progress=download_size / total_size * 100 if total_size > 0 else 0,
|
||||
)
|
||||
|
||||
download_model_task_status[task_id] = task_status
|
||||
|
||||
return task_status
|
||||
|
||||
|
||||
def delete_task_status(task_id: str):
|
||||
download_model_task_status.pop(task_id, None)
|
||||
|
||||
|
||||
async def scan_model_download_task_list():
|
||||
"""
|
||||
Scan the download directory and send the task list to the client.
|
||||
"""
|
||||
download_dir = utils.get_download_path()
|
||||
task_files = utils.search_files(download_dir)
|
||||
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
||||
task_files = sorted(
|
||||
task_files,
|
||||
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
||||
reverse=True,
|
||||
)
|
||||
task_list: list[dict] = []
|
||||
for task_file in task_files:
|
||||
task_id = task_file.replace(".task", "")
|
||||
task_status = get_task_status(task_id)
|
||||
task_list.append(task_status.to_dict())
|
||||
|
||||
return task_list
|
||||
|
||||
|
||||
async def create_model_download_task(task_data: dict, request):
|
||||
"""
|
||||
Creates a download task for the given data.
|
||||
"""
|
||||
model_type = task_data.get("type", None)
|
||||
path_index = int(task_data.get("pathIndex", None))
|
||||
fullname = task_data.get("fullname", None)
|
||||
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
# Check if the model path is valid
|
||||
if os.path.exists(model_path):
|
||||
raise RuntimeError(f"File already exists: {model_path}")
|
||||
|
||||
download_path = utils.get_download_path()
|
||||
|
||||
task_id = uuid.uuid4().hex
|
||||
task_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
if os.path.exists(task_path):
|
||||
raise RuntimeError(f"Task {task_id} already exists")
|
||||
download_platform = task_data.get("downloadPlatform", None)
|
||||
|
||||
try:
|
||||
preview_file = task_data.pop("previewFile", None)
|
||||
utils.save_model_preview_image(task_path, preview_file, download_platform)
|
||||
set_task_content(task_id, task_data)
|
||||
task_status = TaskStatus(
|
||||
taskId=task_id,
|
||||
type=model_type,
|
||||
fullname=fullname,
|
||||
preview=utils.get_model_preview_name(task_path),
|
||||
platform=download_platform,
|
||||
totalSize=float(task_data.get("sizeBytes", 0)),
|
||||
)
|
||||
download_model_task_status[task_id] = task_status
|
||||
await utils.send_json("create_download_task", task_status.to_dict())
|
||||
except Exception as e:
|
||||
await delete_model_download_task(task_id)
|
||||
raise RuntimeError(str(e)) from e
|
||||
|
||||
await download_model(task_id, request)
|
||||
return task_id
|
||||
|
||||
|
||||
async def pause_model_download_task(task_id: str):
|
||||
task_status = get_task_status(task_id=task_id)
|
||||
task_status.status = "pause"
|
||||
|
||||
|
||||
async def delete_model_download_task(task_id: str):
|
||||
task_status = get_task_status(task_id)
|
||||
is_running = task_status.status == "doing"
|
||||
task_status.status = "waiting"
|
||||
await utils.send_json("delete_download_task", task_id)
|
||||
|
||||
# Pause the task
|
||||
if is_running:
|
||||
task_status.status = "pause"
|
||||
time.sleep(1)
|
||||
|
||||
download_dir = utils.get_download_path()
|
||||
task_file_list = os.listdir(download_dir)
|
||||
for task_file in task_file_list:
|
||||
task_file_target = os.path.splitext(task_file)[0]
|
||||
if task_file_target == task_id:
|
||||
delete_task_status(task_id)
|
||||
os.remove(utils.join_path(download_dir, task_file))
|
||||
|
||||
await utils.send_json("delete_download_task", task_id)
|
||||
|
||||
|
||||
async def download_model(task_id: str, request):
|
||||
async def download_task(task_id: str):
|
||||
async def report_progress(task_status: TaskStatus):
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
|
||||
try:
|
||||
# When starting a task from the queue, the task may not exist
|
||||
task_status = get_task_status(task_id)
|
||||
except:
|
||||
return
|
||||
|
||||
# Update task status
|
||||
task_status.status = "doing"
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
|
||||
try:
|
||||
|
||||
# Set download request headers
|
||||
headers = {"User-Agent": config.user_agent}
|
||||
|
||||
download_platform = task_status.platform
|
||||
if download_platform == "civitai":
|
||||
api_key = utils.get_setting_value(request, "api_key.civitai")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
elif download_platform == "huggingface":
|
||||
api_key = utils.get_setting_value(request, "api_key.huggingface")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
progress_interval = 1.0
|
||||
await download_model_file(
|
||||
task_id=task_id,
|
||||
headers=headers,
|
||||
progress_callback=report_progress,
|
||||
interval=progress_interval,
|
||||
)
|
||||
except Exception as e:
|
||||
task_status.status = "pause"
|
||||
task_status.error = str(e)
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
task_status.error = None
|
||||
utils.print_error(str(e))
|
||||
|
||||
try:
|
||||
status = download_thread_pool.submit(download_task, task_id)
|
||||
if status == "Waiting":
|
||||
task_status = get_task_status(task_id)
|
||||
task_status.status = "waiting"
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
except Exception as e:
|
||||
task_status.status = "pause"
|
||||
task_status.error = str(e)
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
task_status.error = None
|
||||
utils.print_error(str(e))
|
||||
|
||||
|
||||
async def download_model_file(
|
||||
task_id: str,
|
||||
headers: dict,
|
||||
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
||||
interval: float = 1.0,
|
||||
):
|
||||
|
||||
async def download_complete():
|
||||
"""
|
||||
Restore the model information from the task file
|
||||
and move the model file to the target directory.
|
||||
"""
|
||||
model_type = task_content.type
|
||||
path_index = task_content.pathIndex
|
||||
fullname = task_content.fullname
|
||||
# Write description file
|
||||
description = task_content.description
|
||||
description_file = utils.join_path(download_path, f"{task_id}.md")
|
||||
with open(description_file, "w", encoding="utf-8", newline="") as f:
|
||||
f.write(description)
|
||||
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
|
||||
utils.rename_model(download_tmp_file, model_path)
|
||||
|
||||
time.sleep(1)
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
os.remove(task_file)
|
||||
await utils.send_json("complete_download_task", task_id)
|
||||
|
||||
async def update_progress():
|
||||
nonlocal last_update_time
|
||||
nonlocal last_downloaded_size
|
||||
progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
|
||||
task_status.downloadedSize = downloaded_size
|
||||
task_status.progress = progress
|
||||
task_status.bps = downloaded_size - last_downloaded_size
|
||||
await progress_callback(task_status)
|
||||
last_update_time = time.time()
|
||||
last_downloaded_size = downloaded_size
|
||||
|
||||
task_status = get_task_status(task_id)
|
||||
task_content = get_task_content(task_id)
|
||||
|
||||
# Check download uri
|
||||
model_url = task_content.downloadUrl
|
||||
if not model_url:
|
||||
raise RuntimeError("No downloadUrl found")
|
||||
|
||||
download_path = utils.get_download_path()
|
||||
download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
|
||||
|
||||
downloaded_size = 0
|
||||
if os.path.isfile(download_tmp_file):
|
||||
downloaded_size = os.path.getsize(download_tmp_file)
|
||||
headers["Range"] = f"bytes={downloaded_size}-"
|
||||
|
||||
total_size = task_content.sizeBytes
|
||||
|
||||
if total_size > 0 and downloaded_size == total_size:
|
||||
await download_complete()
|
||||
return
|
||||
|
||||
last_update_time = time.time()
|
||||
last_downloaded_size = downloaded_size
|
||||
|
||||
response = requests.get(
|
||||
url=model_url,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
|
||||
if response.status_code not in (200, 206):
|
||||
raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}")
|
||||
|
||||
# Some models require logging in before they can be downloaded.
|
||||
# If no token is carried, it will be redirected to the login page.
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("text/html"):
|
||||
# TODO More checks
|
||||
# In addition to requiring login to download, there may be other restrictions.
|
||||
# The currently one situation is early access??? issues#43
|
||||
# Due to the lack of test data, let’s put it aside for now.
|
||||
# If it cannot be downloaded, a redirect will definitely occur.
|
||||
# Maybe consider getting the redirect url from response.history to make a judgment.
|
||||
# Here we also need to consider how different websites are processed.
|
||||
raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.")
|
||||
|
||||
# When parsing model information from HuggingFace API,
|
||||
# the file size was not found and needs to be obtained from the response header.
|
||||
if total_size == 0:
|
||||
total_size = float(response.headers.get("content-length", 0))
|
||||
task_content.sizeBytes = total_size
|
||||
task_status.totalSize = total_size
|
||||
set_task_content(task_id, task_content)
|
||||
await utils.send_json("update_download_task", task_content.to_dict())
|
||||
|
||||
with open(download_tmp_file, "ab") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if task_status.status == "pause":
|
||||
break
|
||||
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
await update_progress()
|
||||
|
||||
if total_size > 0 and downloaded_size == total_size:
|
||||
await download_complete()
|
||||
else:
|
||||
task_status.status = "pause"
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
class ModelDownload:
|
||||
def add_routes(self, routes):
|
||||
|
||||
@@ -420,7 +101,7 @@ class ModelDownload:
|
||||
Read download task list.
|
||||
"""
|
||||
try:
|
||||
result = await scan_model_download_task_list()
|
||||
result = await self.scan_model_download_task_list()
|
||||
return web.json_response({"success": True, "data": result})
|
||||
except Exception as e:
|
||||
error_msg = f"Read download task list failed: {e}"
|
||||
@@ -439,9 +120,9 @@ class ModelDownload:
|
||||
json_data = await request.json()
|
||||
status = json_data.get("status", None)
|
||||
if status == "pause":
|
||||
await pause_model_download_task(task_id)
|
||||
await self.pause_model_download_task(task_id)
|
||||
elif status == "resume":
|
||||
await download_model(task_id, request)
|
||||
await self.download_model(task_id, request)
|
||||
else:
|
||||
raise web.HTTPBadRequest(reason="Invalid status")
|
||||
|
||||
@@ -458,7 +139,7 @@ class ModelDownload:
|
||||
"""
|
||||
task_id = request.match_info.get("task_id", None)
|
||||
try:
|
||||
await delete_model_download_task(task_id)
|
||||
await self.delete_model_download_task(task_id)
|
||||
return web.json_response({"success": True})
|
||||
except Exception as e:
|
||||
error_msg = f"Delete download task failed: {str(e)}"
|
||||
@@ -483,9 +164,321 @@ class ModelDownload:
|
||||
task_data = await request.post()
|
||||
task_data = dict(task_data)
|
||||
try:
|
||||
task_id = await create_model_download_task(task_data, request)
|
||||
task_id = await self.create_model_download_task(task_data, request)
|
||||
return web.json_response({"success": True, "data": {"taskId": task_id}})
|
||||
except Exception as e:
|
||||
error_msg = f"Create model download task failed: {str(e)}"
|
||||
utils.print_error(error_msg)
|
||||
return web.json_response({"success": False, "error": error_msg})
|
||||
|
||||
download_model_task_status: dict[str, TaskStatus] = {}
|
||||
|
||||
download_thread_pool = thread.DownloadThreadPool()
|
||||
|
||||
def set_task_content(self, task_id: str, task_content: Union[TaskContent, dict]):
|
||||
download_path = utils.get_download_path()
|
||||
task_file_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
utils.save_dict_pickle_file(task_file_path, task_content)
|
||||
|
||||
def get_task_content(self, task_id: str):
|
||||
download_path = utils.get_download_path()
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
if not os.path.isfile(task_file):
|
||||
raise RuntimeError(f"Task {task_id} not found")
|
||||
task_content = utils.load_dict_pickle_file(task_file)
|
||||
if isinstance(task_content, TaskContent):
|
||||
return task_content
|
||||
return TaskContent(**task_content)
|
||||
|
||||
def get_task_status(self, task_id: str):
|
||||
task_status = self.download_model_task_status.get(task_id, None)
|
||||
|
||||
if task_status is None:
|
||||
download_path = utils.get_download_path()
|
||||
task_content = self.get_task_content(task_id)
|
||||
download_file = utils.join_path(download_path, f"{task_id}.download")
|
||||
download_size = 0
|
||||
if os.path.exists(download_file):
|
||||
download_size = os.path.getsize(download_file)
|
||||
|
||||
total_size = task_content.sizeBytes
|
||||
task_status = TaskStatus(
|
||||
taskId=task_id,
|
||||
type=task_content.type,
|
||||
fullname=task_content.fullname,
|
||||
preview=utils.get_model_preview_name(download_file),
|
||||
platform=task_content.downloadPlatform,
|
||||
downloadedSize=download_size,
|
||||
totalSize=task_content.sizeBytes,
|
||||
progress=download_size / total_size * 100 if total_size > 0 else 0,
|
||||
)
|
||||
|
||||
self.download_model_task_status[task_id] = task_status
|
||||
|
||||
return task_status
|
||||
|
||||
def delete_task_status(self, task_id: str):
|
||||
self.download_model_task_status.pop(task_id, None)
|
||||
|
||||
async def scan_model_download_task_list(self):
|
||||
"""
|
||||
Scan the download directory and send the task list to the client.
|
||||
"""
|
||||
download_dir = utils.get_download_path()
|
||||
task_files = utils.search_files(download_dir)
|
||||
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
||||
task_files = sorted(
|
||||
task_files,
|
||||
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
||||
reverse=True,
|
||||
)
|
||||
task_list: list[dict] = []
|
||||
for task_file in task_files:
|
||||
task_id = task_file.replace(".task", "")
|
||||
task_status = self.get_task_status(task_id)
|
||||
task_list.append(task_status.to_dict())
|
||||
|
||||
return task_list
|
||||
|
||||
async def create_model_download_task(self, task_data: dict, request):
|
||||
"""
|
||||
Creates a download task for the given data.
|
||||
"""
|
||||
model_type = task_data.get("type", None)
|
||||
path_index = int(task_data.get("pathIndex", None))
|
||||
fullname = task_data.get("fullname", None)
|
||||
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
# Check if the model path is valid
|
||||
if os.path.exists(model_path):
|
||||
raise RuntimeError(f"File already exists: {model_path}")
|
||||
|
||||
download_path = utils.get_download_path()
|
||||
|
||||
task_id = uuid.uuid4().hex
|
||||
task_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
if os.path.exists(task_path):
|
||||
raise RuntimeError(f"Task {task_id} already exists")
|
||||
download_platform = task_data.get("downloadPlatform", None)
|
||||
|
||||
try:
|
||||
preview_file = task_data.pop("previewFile", None)
|
||||
utils.save_model_preview_image(task_path, preview_file, download_platform)
|
||||
self.set_task_content(task_id, task_data)
|
||||
task_status = TaskStatus(
|
||||
taskId=task_id,
|
||||
type=model_type,
|
||||
fullname=fullname,
|
||||
preview=utils.get_model_preview_name(task_path),
|
||||
platform=download_platform,
|
||||
totalSize=float(task_data.get("sizeBytes", 0)),
|
||||
)
|
||||
self.download_model_task_status[task_id] = task_status
|
||||
await utils.send_json("create_download_task", task_status.to_dict())
|
||||
except Exception as e:
|
||||
await self.delete_model_download_task(task_id)
|
||||
raise RuntimeError(str(e)) from e
|
||||
|
||||
await self.download_model(task_id, request)
|
||||
return task_id
|
||||
|
||||
async def pause_model_download_task(self, task_id: str):
|
||||
task_status = self.get_task_status(task_id=task_id)
|
||||
task_status.status = "pause"
|
||||
|
||||
async def delete_model_download_task(self, task_id: str):
|
||||
task_status = self.get_task_status(task_id)
|
||||
is_running = task_status.status == "doing"
|
||||
task_status.status = "waiting"
|
||||
await utils.send_json("delete_download_task", task_id)
|
||||
|
||||
# Pause the task
|
||||
if is_running:
|
||||
task_status.status = "pause"
|
||||
time.sleep(1)
|
||||
|
||||
download_dir = utils.get_download_path()
|
||||
task_file_list = os.listdir(download_dir)
|
||||
for task_file in task_file_list:
|
||||
task_file_target = os.path.splitext(task_file)[0]
|
||||
if task_file_target == task_id:
|
||||
self.delete_task_status(task_id)
|
||||
os.remove(utils.join_path(download_dir, task_file))
|
||||
|
||||
await utils.send_json("delete_download_task", task_id)
|
||||
|
||||
async def download_model(self, task_id: str, request):
|
||||
async def download_task(task_id: str):
|
||||
async def report_progress(task_status: TaskStatus):
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
|
||||
try:
|
||||
# When starting a task from the queue, the task may not exist
|
||||
task_status = self.get_task_status(task_id)
|
||||
except:
|
||||
return
|
||||
|
||||
# Update task status
|
||||
task_status.status = "doing"
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
|
||||
try:
|
||||
|
||||
# Set download request headers
|
||||
headers = {"User-Agent": config.user_agent}
|
||||
|
||||
download_platform = task_status.platform
|
||||
if download_platform == "civitai":
|
||||
api_key = utils.get_setting_value(request, "api_key.civitai")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
elif download_platform == "huggingface":
|
||||
api_key = utils.get_setting_value(request, "api_key.huggingface")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
progress_interval = 1.0
|
||||
await self.download_model_file(
|
||||
task_id=task_id,
|
||||
headers=headers,
|
||||
progress_callback=report_progress,
|
||||
interval=progress_interval,
|
||||
)
|
||||
except Exception as e:
|
||||
task_status.status = "pause"
|
||||
task_status.error = str(e)
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
task_status.error = None
|
||||
utils.print_error(str(e))
|
||||
|
||||
try:
|
||||
status = self.download_thread_pool.submit(download_task, task_id)
|
||||
if status == "Waiting":
|
||||
task_status = self.get_task_status(task_id)
|
||||
task_status.status = "waiting"
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
except Exception as e:
|
||||
task_status.status = "pause"
|
||||
task_status.error = str(e)
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
task_status.error = None
|
||||
utils.print_error(str(e))
|
||||
|
||||
async def download_model_file(
|
||||
self,
|
||||
task_id: str,
|
||||
headers: dict,
|
||||
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
||||
interval: float = 1.0,
|
||||
):
|
||||
|
||||
async def download_complete():
|
||||
"""
|
||||
Restore the model information from the task file
|
||||
and move the model file to the target directory.
|
||||
"""
|
||||
model_type = task_content.type
|
||||
path_index = task_content.pathIndex
|
||||
fullname = task_content.fullname
|
||||
# Write description file
|
||||
description = task_content.description
|
||||
description_file = utils.join_path(download_path, f"{task_id}.md")
|
||||
with open(description_file, "w", encoding="utf-8", newline="") as f:
|
||||
f.write(description)
|
||||
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
|
||||
utils.rename_model(download_tmp_file, model_path)
|
||||
|
||||
time.sleep(1)
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
os.remove(task_file)
|
||||
await utils.send_json("complete_download_task", task_id)
|
||||
|
||||
async def update_progress():
|
||||
nonlocal last_update_time
|
||||
nonlocal last_downloaded_size
|
||||
progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
|
||||
task_status.downloadedSize = downloaded_size
|
||||
task_status.progress = progress
|
||||
task_status.bps = downloaded_size - last_downloaded_size
|
||||
await progress_callback(task_status)
|
||||
last_update_time = time.time()
|
||||
last_downloaded_size = downloaded_size
|
||||
|
||||
task_status = self.get_task_status(task_id)
|
||||
task_content = self.get_task_content(task_id)
|
||||
|
||||
# Check download uri
|
||||
model_url = task_content.downloadUrl
|
||||
if not model_url:
|
||||
raise RuntimeError("No downloadUrl found")
|
||||
|
||||
download_path = utils.get_download_path()
|
||||
download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
|
||||
|
||||
downloaded_size = 0
|
||||
if os.path.isfile(download_tmp_file):
|
||||
downloaded_size = os.path.getsize(download_tmp_file)
|
||||
headers["Range"] = f"bytes={downloaded_size}-"
|
||||
|
||||
total_size = task_content.sizeBytes
|
||||
|
||||
if total_size > 0 and downloaded_size == total_size:
|
||||
await download_complete()
|
||||
return
|
||||
|
||||
last_update_time = time.time()
|
||||
last_downloaded_size = downloaded_size
|
||||
|
||||
response = requests.get(
|
||||
url=model_url,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
|
||||
if response.status_code not in (200, 206):
|
||||
raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}")
|
||||
|
||||
# Some models require logging in before they can be downloaded.
|
||||
# If no token is carried, it will be redirected to the login page.
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("text/html"):
|
||||
# TODO More checks
|
||||
# In addition to requiring login to download, there may be other restrictions.
|
||||
# The currently one situation is early access??? issues#43
|
||||
# Due to the lack of test data, let’s put it aside for now.
|
||||
# If it cannot be downloaded, a redirect will definitely occur.
|
||||
# Maybe consider getting the redirect url from response.history to make a judgment.
|
||||
# Here we also need to consider how different websites are processed.
|
||||
raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.")
|
||||
|
||||
# When parsing model information from HuggingFace API,
|
||||
# the file size was not found and needs to be obtained from the response header.
|
||||
if total_size == 0:
|
||||
total_size = float(response.headers.get("content-length", 0))
|
||||
task_content.sizeBytes = total_size
|
||||
task_status.totalSize = total_size
|
||||
self.set_task_content(task_id, task_content)
|
||||
await utils.send_json("update_download_task", task_content.to_dict())
|
||||
|
||||
with open(download_tmp_file, "ab") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if task_status.status == "pause":
|
||||
break
|
||||
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
await update_progress()
|
||||
|
||||
if total_size > 0 and downloaded_size == total_size:
|
||||
await download_complete()
|
||||
else:
|
||||
task_status.status = "pause"
|
||||
await utils.send_json("update_download_task", task_status.to_dict())
|
||||
|
||||
Reference in New Issue
Block a user