fix: cross-platform paths

This commit is contained in:
hayden
2024-11-05 16:42:36 +08:00
parent f9b0afcbf5
commit 7a183464ae
4 changed files with 52 additions and 44 deletions

View File

@@ -46,13 +46,13 @@ download_thread_pool = thread.DownloadThreadPool()
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
download_path = utils.get_download_path()
task_file_path = os.path.join(download_path, f"{task_id}.task")
task_file_path = utils.join_path(download_path, f"{task_id}.task")
utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
def get_task_content(task_id: str):
download_path = utils.get_download_path()
task_file = os.path.join(download_path, f"{task_id}.task")
task_file = utils.join_path(download_path, f"{task_id}.task")
if not os.path.isfile(task_file):
raise RuntimeError(f"Task {task_id} not found")
task_content = utils.load_dict_pickle_file(task_file)
@@ -67,7 +67,7 @@ def get_task_status(task_id: str):
if task_status is None:
download_path = utils.get_download_path()
task_content = get_task_content(task_id)
download_file = os.path.join(download_path, f"{task_id}.download")
download_file = utils.join_path(download_path, f"{task_id}.download")
download_size = 0
if os.path.exists(download_file):
download_size = os.path.getsize(download_file)
@@ -103,7 +103,7 @@ async def scan_model_download_task_list(sid: str):
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = sorted(
task_files,
key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime,
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
reverse=True,
)
task_list: list[dict] = []
@@ -135,7 +135,7 @@ async def create_model_download_task(post: dict):
download_path = utils.get_download_path()
task_id = uuid.uuid4().hex
task_path = os.path.join(download_path, f"{task_id}.task")
task_path = utils.join_path(download_path, f"{task_id}.task")
if os.path.exists(task_path):
raise RuntimeError(f"Task {task_id} already exists")
@@ -183,7 +183,7 @@ async def delete_model_download_task(task_id: str):
task_file_target = os.path.splitext(task_file)[0]
if task_file_target == task_id:
delete_task_status(task_id)
os.remove(os.path.join(download_dir, task_file))
os.remove(utils.join_path(download_dir, task_file))
await socket.send_json("deleteDownloadTask", task_id)
@@ -264,7 +264,7 @@ async def download_model_file(
fullname = task_content.fullname
# Write description file
description = task_content.description
description_file = os.path.join(download_path, f"{task_id}.md")
description_file = utils.join_path(download_path, f"{task_id}.md")
with open(description_file, "w") as f:
f.write(description)
@@ -273,7 +273,7 @@ async def download_model_file(
utils.rename_model(download_tmp_file, model_path)
time.sleep(1)
task_file = os.path.join(download_path, f"{task_id}.task")
task_file = utils.join_path(download_path, f"{task_id}.task")
os.remove(task_file)
await socket.send_json("completeDownloadTask", task_id)
@@ -297,7 +297,7 @@ async def download_model_file(
raise RuntimeError("No downloadUrl found")
download_path = utils.get_download_path()
download_tmp_file = os.path.join(download_path, f"{task_id}.download")
download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
downloaded_size = 0
if os.path.isfile(download_tmp_file):