fix: cross-platform paths
This commit is contained in:
@@ -46,13 +46,13 @@ download_thread_pool = thread.DownloadThreadPool()
|
||||
|
||||
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
|
||||
download_path = utils.get_download_path()
|
||||
task_file_path = os.path.join(download_path, f"{task_id}.task")
|
||||
task_file_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
|
||||
|
||||
|
||||
def get_task_content(task_id: str):
|
||||
download_path = utils.get_download_path()
|
||||
task_file = os.path.join(download_path, f"{task_id}.task")
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
if not os.path.isfile(task_file):
|
||||
raise RuntimeError(f"Task {task_id} not found")
|
||||
task_content = utils.load_dict_pickle_file(task_file)
|
||||
@@ -67,7 +67,7 @@ def get_task_status(task_id: str):
|
||||
if task_status is None:
|
||||
download_path = utils.get_download_path()
|
||||
task_content = get_task_content(task_id)
|
||||
download_file = os.path.join(download_path, f"{task_id}.download")
|
||||
download_file = utils.join_path(download_path, f"{task_id}.download")
|
||||
download_size = 0
|
||||
if os.path.exists(download_file):
|
||||
download_size = os.path.getsize(download_file)
|
||||
@@ -103,7 +103,7 @@ async def scan_model_download_task_list(sid: str):
|
||||
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
|
||||
task_files = sorted(
|
||||
task_files,
|
||||
key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime,
|
||||
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
|
||||
reverse=True,
|
||||
)
|
||||
task_list: list[dict] = []
|
||||
@@ -135,7 +135,7 @@ async def create_model_download_task(post: dict):
|
||||
download_path = utils.get_download_path()
|
||||
|
||||
task_id = uuid.uuid4().hex
|
||||
task_path = os.path.join(download_path, f"{task_id}.task")
|
||||
task_path = utils.join_path(download_path, f"{task_id}.task")
|
||||
if os.path.exists(task_path):
|
||||
raise RuntimeError(f"Task {task_id} already exists")
|
||||
|
||||
@@ -183,7 +183,7 @@ async def delete_model_download_task(task_id: str):
|
||||
task_file_target = os.path.splitext(task_file)[0]
|
||||
if task_file_target == task_id:
|
||||
delete_task_status(task_id)
|
||||
os.remove(os.path.join(download_dir, task_file))
|
||||
os.remove(utils.join_path(download_dir, task_file))
|
||||
|
||||
await socket.send_json("deleteDownloadTask", task_id)
|
||||
|
||||
@@ -264,7 +264,7 @@ async def download_model_file(
|
||||
fullname = task_content.fullname
|
||||
# Write description file
|
||||
description = task_content.description
|
||||
description_file = os.path.join(download_path, f"{task_id}.md")
|
||||
description_file = utils.join_path(download_path, f"{task_id}.md")
|
||||
with open(description_file, "w") as f:
|
||||
f.write(description)
|
||||
|
||||
@@ -273,7 +273,7 @@ async def download_model_file(
|
||||
utils.rename_model(download_tmp_file, model_path)
|
||||
|
||||
time.sleep(1)
|
||||
task_file = os.path.join(download_path, f"{task_id}.task")
|
||||
task_file = utils.join_path(download_path, f"{task_id}.task")
|
||||
os.remove(task_file)
|
||||
await socket.send_json("completeDownloadTask", task_id)
|
||||
|
||||
@@ -297,7 +297,7 @@ async def download_model_file(
|
||||
raise RuntimeError("No downloadUrl found")
|
||||
|
||||
download_path = utils.get_download_path()
|
||||
download_tmp_file = os.path.join(download_path, f"{task_id}.download")
|
||||
download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
|
||||
|
||||
downloaded_size = 0
|
||||
if os.path.isfile(download_tmp_file):
|
||||
|
||||
@@ -46,16 +46,16 @@ def scan_models():
|
||||
image_dict = utils.file_list_to_name_dict(images)
|
||||
|
||||
for fullname in models:
|
||||
fullname = fullname.replace(os.path.sep, "/")
|
||||
fullname = utils.normalize_path(fullname)
|
||||
basename = os.path.splitext(fullname)[0]
|
||||
extension = os.path.splitext(fullname)[1]
|
||||
|
||||
abs_path = os.path.join(base_path, fullname)
|
||||
abs_path = utils.join_path(base_path, fullname)
|
||||
file_stats = os.stat(abs_path)
|
||||
|
||||
# Resolve preview
|
||||
image_name = image_dict.get(basename, "no-preview.png")
|
||||
abs_image_path = os.path.join(base_path, image_name)
|
||||
abs_image_path = utils.join_path(base_path, image_name)
|
||||
if os.path.isfile(abs_image_path):
|
||||
image_state = os.stat(abs_image_path)
|
||||
image_timestamp = round(image_state.st_mtime_ns / 1000000)
|
||||
@@ -87,7 +87,7 @@ def get_model_info(model_path: str):
|
||||
metadata = utils.get_model_metadata(model_path)
|
||||
|
||||
description_file = utils.get_model_description_name(model_path)
|
||||
description_file = os.path.join(directory, description_file)
|
||||
description_file = utils.join_path(directory, description_file)
|
||||
description = None
|
||||
if os.path.isfile(description_file):
|
||||
with open(description_file, "r", encoding="utf-8") as f:
|
||||
@@ -128,11 +128,11 @@ def remove_model(model_path: str):
|
||||
|
||||
model_previews = utils.get_model_all_images(model_path)
|
||||
for preview in model_previews:
|
||||
os.remove(os.path.join(model_dirname, preview))
|
||||
os.remove(utils.join_path(model_dirname, preview))
|
||||
|
||||
model_descriptions = utils.get_model_all_descriptions(model_path)
|
||||
for description in model_descriptions:
|
||||
os.remove(os.path.join(model_dirname, description))
|
||||
os.remove(utils.join_path(model_dirname, description))
|
||||
|
||||
|
||||
async def create_model_download_task(post):
|
||||
|
||||
56
py/utils.py
56
py/utils.py
@@ -15,9 +15,18 @@ from typing import Any
|
||||
from . import config
|
||||
|
||||
|
||||
def normalize_path(path: str):
|
||||
normpath = os.path.normpath(path)
|
||||
return normpath.replace(os.path.sep, "/")
|
||||
|
||||
|
||||
def join_path(path: str, *paths: list[str]):
|
||||
return normalize_path(os.path.join(path, *paths))
|
||||
|
||||
|
||||
def get_current_version():
|
||||
try:
|
||||
pyproject_path = os.path.join(config.extension_uri, "pyproject.toml")
|
||||
pyproject_path = join_path(config.extension_uri, "pyproject.toml")
|
||||
config_parser = configparser.ConfigParser()
|
||||
config_parser.read(pyproject_path)
|
||||
version = config_parser.get("project", "version")
|
||||
@@ -27,13 +36,13 @@ def get_current_version():
|
||||
|
||||
|
||||
def download_web_distribution(version: str):
|
||||
web_path = os.path.join(config.extension_uri, "web")
|
||||
dev_web_file = os.path.join(web_path, "manager-dev.js")
|
||||
web_path = join_path(config.extension_uri, "web")
|
||||
dev_web_file = join_path(web_path, "manager-dev.js")
|
||||
if os.path.exists(dev_web_file):
|
||||
return
|
||||
|
||||
web_version = "0.0.0"
|
||||
version_file = os.path.join(web_path, "version.yaml")
|
||||
version_file = join_path(web_path, "version.yaml")
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, "r") as f:
|
||||
version_content = yaml.safe_load(f)
|
||||
@@ -49,7 +58,7 @@ def download_web_distribution(version: str):
|
||||
response = requests.get(download_url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
temp_file = os.path.join(config.extension_uri, "temp.tar.gz")
|
||||
temp_file = join_path(config.extension_uri, "temp.tar.gz")
|
||||
with open(temp_file, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
@@ -82,7 +91,8 @@ def resolve_model_base_paths():
|
||||
continue
|
||||
if folder == "custom_nodes":
|
||||
continue
|
||||
config.model_base_paths[folder] = folder_paths.get_folder_paths(folder)
|
||||
folders = folder_paths.get_folder_paths(folder)
|
||||
config.model_base_paths[folder] = [normalize_path(f) for f in folders]
|
||||
|
||||
|
||||
def get_full_path(model_type: str, path_index: int, filename: str):
|
||||
@@ -93,7 +103,8 @@ def get_full_path(model_type: str, path_index: int, filename: str):
|
||||
if not path_index < len(folders):
|
||||
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
||||
base_path = folders[path_index]
|
||||
return os.path.join(base_path, filename)
|
||||
full_path = join_path(base_path, filename)
|
||||
return full_path
|
||||
|
||||
|
||||
def get_valid_full_path(model_type: str, path_index: int, filename: str):
|
||||
@@ -104,7 +115,7 @@ def get_valid_full_path(model_type: str, path_index: int, filename: str):
|
||||
if not path_index < len(folders):
|
||||
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
||||
base_path = folders[path_index]
|
||||
full_path = os.path.join(base_path, filename)
|
||||
full_path = join_path(base_path, filename)
|
||||
if os.path.isfile(full_path):
|
||||
return full_path
|
||||
elif os.path.islink(full_path):
|
||||
@@ -114,7 +125,7 @@ def get_valid_full_path(model_type: str, path_index: int, filename: str):
|
||||
|
||||
|
||||
def get_download_path():
|
||||
download_path = os.path.join(config.extension_uri, "downloads")
|
||||
download_path = join_path(config.extension_uri, "downloads")
|
||||
if not os.path.exists(download_path):
|
||||
os.makedirs(download_path)
|
||||
return download_path
|
||||
@@ -124,12 +135,12 @@ def recursive_search_files(directory: str):
|
||||
files, folder_all = folder_paths.recursive_search(
|
||||
directory, excluded_dir_names=[".git"]
|
||||
)
|
||||
return files
|
||||
return [normalize_path(f) for f in files]
|
||||
|
||||
|
||||
def search_files(directory: str):
|
||||
entries = os.listdir(directory)
|
||||
files = [f for f in entries if os.path.isfile(os.path.join(directory, f))]
|
||||
files = [f for f in entries if os.path.isfile(join_path(directory, f))]
|
||||
return files
|
||||
|
||||
|
||||
@@ -137,7 +148,6 @@ def file_list_to_name_dict(files: list[str]):
|
||||
file_dict: dict[str, str] = {}
|
||||
for file in files:
|
||||
filename = os.path.splitext(file)[0]
|
||||
filename = filename.replace(os.path.sep, "/")
|
||||
file_dict[filename] = file
|
||||
return file_dict
|
||||
|
||||
@@ -194,13 +204,13 @@ def save_model_preview_image(model_path: str, image_file: Any):
|
||||
for image in old_preview_images:
|
||||
if os.path.splitext(image)[1].endswith(".preview"):
|
||||
a1111_civitai_helper_image = True
|
||||
image_path = os.path.join(base_dirname, image)
|
||||
image_path = join_path(base_dirname, image)
|
||||
os.remove(image_path)
|
||||
|
||||
# save new preview image
|
||||
basename = os.path.splitext(os.path.basename(model_path))[0]
|
||||
extension = f".{content_type.split('/')[1]}"
|
||||
new_preview_path = os.path.join(base_dirname, f"{basename}{extension}")
|
||||
new_preview_path = join_path(base_dirname, f"{basename}{extension}")
|
||||
|
||||
with open(new_preview_path, "wb") as f:
|
||||
f.write(image_file.file.read())
|
||||
@@ -210,7 +220,7 @@ def save_model_preview_image(model_path: str, image_file: Any):
|
||||
"""
|
||||
Keep preview image of a1111_civitai_helper
|
||||
"""
|
||||
new_preview_path = os.path.join(base_dirname, f"{basename}.preview{extension}")
|
||||
new_preview_path = join_path(base_dirname, f"{basename}.preview{extension}")
|
||||
with open(new_preview_path, "wb") as f:
|
||||
f.write(image_file.file.read())
|
||||
|
||||
@@ -244,13 +254,13 @@ def save_model_description(model_path: str, content: Any):
|
||||
# remove old descriptions
|
||||
old_descriptions = get_model_all_descriptions(model_path)
|
||||
for desc in old_descriptions:
|
||||
description_path = os.path.join(base_dirname, desc)
|
||||
description_path = join_path(base_dirname, desc)
|
||||
os.remove(description_path)
|
||||
|
||||
# save new description
|
||||
basename = os.path.splitext(os.path.basename(model_path))[0]
|
||||
extension = ".md"
|
||||
new_desc_path = os.path.join(base_dirname, f"{basename}{extension}")
|
||||
new_desc_path = join_path(base_dirname, f"{basename}{extension}")
|
||||
|
||||
with open(new_desc_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
@@ -278,23 +288,21 @@ def rename_model(model_path: str, new_model_path: str):
|
||||
# move preview
|
||||
previews = get_model_all_images(model_path)
|
||||
for preview in previews:
|
||||
preview_path = os.path.join(model_dirname, preview)
|
||||
preview_path = join_path(model_dirname, preview)
|
||||
preview_name = os.path.splitext(preview)[0]
|
||||
preview_ext = os.path.splitext(preview)[1]
|
||||
new_preview_path = (
|
||||
os.path.join(new_model_dirname, new_model_name + preview_ext)
|
||||
join_path(new_model_dirname, new_model_name + preview_ext)
|
||||
if preview_name == model_name
|
||||
else os.path.join(
|
||||
new_model_dirname, new_model_name + ".preview" + preview_ext
|
||||
)
|
||||
else join_path(new_model_dirname, new_model_name + ".preview" + preview_ext)
|
||||
)
|
||||
shutil.move(preview_path, new_preview_path)
|
||||
|
||||
# move description
|
||||
description = get_model_description_name(model_path)
|
||||
description_path = os.path.join(model_dirname, description)
|
||||
description_path = join_path(model_dirname, description)
|
||||
if os.path.isfile(description_path):
|
||||
new_description_path = os.path.join(new_model_dirname, f"{new_model_name}.md")
|
||||
new_description_path = join_path(new_model_dirname, f"{new_model_name}.md")
|
||||
shutil.move(description_path, new_description_path)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user