fix: cross-platform paths

This commit is contained in:
hayden
2024-11-05 16:42:36 +08:00
parent f9b0afcbf5
commit 7a183464ae
4 changed files with 52 additions and 44 deletions

View File

@@ -5,7 +5,7 @@ from .py import utils
# Init config settings
config.extension_uri = os.path.dirname(__file__)
config.extension_uri = utils.normalize_path(os.path.dirname(__file__))
utils.resolve_model_base_paths()
version = utils.get_current_version()
@@ -173,12 +173,12 @@ async def read_model_preview(request):
try:
folders = folder_paths.get_folder_paths(model_type)
base_path = folders[index]
abs_path = os.path.join(base_path, filename)
abs_path = utils.join_path(base_path, filename)
except:
abs_path = extension_uri
if not os.path.isfile(abs_path):
abs_path = os.path.join(extension_uri, "assets", "no-preview.png")
abs_path = utils.join_path(extension_uri, "assets", "no-preview.png")
return web.FileResponse(abs_path)
@@ -188,10 +188,10 @@ async def read_download_preview(request):
extension_uri = config.extension_uri
download_path = utils.get_download_path()
preview_path = os.path.join(download_path, filename)
preview_path = utils.join_path(download_path, filename)
if not os.path.isfile(preview_path):
preview_path = os.path.join(extension_uri, "assets", "no-preview.png")
preview_path = utils.join_path(extension_uri, "assets", "no-preview.png")
return web.FileResponse(preview_path)

View File

@@ -46,13 +46,13 @@ download_thread_pool = thread.DownloadThreadPool()
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
download_path = utils.get_download_path()
task_file_path = os.path.join(download_path, f"{task_id}.task")
task_file_path = utils.join_path(download_path, f"{task_id}.task")
utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
def get_task_content(task_id: str):
download_path = utils.get_download_path()
task_file = os.path.join(download_path, f"{task_id}.task")
task_file = utils.join_path(download_path, f"{task_id}.task")
if not os.path.isfile(task_file):
raise RuntimeError(f"Task {task_id} not found")
task_content = utils.load_dict_pickle_file(task_file)
@@ -67,7 +67,7 @@ def get_task_status(task_id: str):
if task_status is None:
download_path = utils.get_download_path()
task_content = get_task_content(task_id)
download_file = os.path.join(download_path, f"{task_id}.download")
download_file = utils.join_path(download_path, f"{task_id}.download")
download_size = 0
if os.path.exists(download_file):
download_size = os.path.getsize(download_file)
@@ -103,7 +103,7 @@ async def scan_model_download_task_list(sid: str):
task_files = folder_paths.filter_files_extensions(task_files, [".task"])
task_files = sorted(
task_files,
key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime,
key=lambda x: os.stat(utils.join_path(download_dir, x)).st_ctime,
reverse=True,
)
task_list: list[dict] = []
@@ -135,7 +135,7 @@ async def create_model_download_task(post: dict):
download_path = utils.get_download_path()
task_id = uuid.uuid4().hex
task_path = os.path.join(download_path, f"{task_id}.task")
task_path = utils.join_path(download_path, f"{task_id}.task")
if os.path.exists(task_path):
raise RuntimeError(f"Task {task_id} already exists")
@@ -183,7 +183,7 @@ async def delete_model_download_task(task_id: str):
task_file_target = os.path.splitext(task_file)[0]
if task_file_target == task_id:
delete_task_status(task_id)
os.remove(os.path.join(download_dir, task_file))
os.remove(utils.join_path(download_dir, task_file))
await socket.send_json("deleteDownloadTask", task_id)
@@ -264,7 +264,7 @@ async def download_model_file(
fullname = task_content.fullname
# Write description file
description = task_content.description
description_file = os.path.join(download_path, f"{task_id}.md")
description_file = utils.join_path(download_path, f"{task_id}.md")
with open(description_file, "w") as f:
f.write(description)
@@ -273,7 +273,7 @@ async def download_model_file(
utils.rename_model(download_tmp_file, model_path)
time.sleep(1)
task_file = os.path.join(download_path, f"{task_id}.task")
task_file = utils.join_path(download_path, f"{task_id}.task")
os.remove(task_file)
await socket.send_json("completeDownloadTask", task_id)
@@ -297,7 +297,7 @@ async def download_model_file(
raise RuntimeError("No downloadUrl found")
download_path = utils.get_download_path()
download_tmp_file = os.path.join(download_path, f"{task_id}.download")
download_tmp_file = utils.join_path(download_path, f"{task_id}.download")
downloaded_size = 0
if os.path.isfile(download_tmp_file):

View File

@@ -46,16 +46,16 @@ def scan_models():
image_dict = utils.file_list_to_name_dict(images)
for fullname in models:
fullname = fullname.replace(os.path.sep, "/")
fullname = utils.normalize_path(fullname)
basename = os.path.splitext(fullname)[0]
extension = os.path.splitext(fullname)[1]
abs_path = os.path.join(base_path, fullname)
abs_path = utils.join_path(base_path, fullname)
file_stats = os.stat(abs_path)
# Resolve preview
image_name = image_dict.get(basename, "no-preview.png")
abs_image_path = os.path.join(base_path, image_name)
abs_image_path = utils.join_path(base_path, image_name)
if os.path.isfile(abs_image_path):
image_state = os.stat(abs_image_path)
image_timestamp = round(image_state.st_mtime_ns / 1000000)
@@ -87,7 +87,7 @@ def get_model_info(model_path: str):
metadata = utils.get_model_metadata(model_path)
description_file = utils.get_model_description_name(model_path)
description_file = os.path.join(directory, description_file)
description_file = utils.join_path(directory, description_file)
description = None
if os.path.isfile(description_file):
with open(description_file, "r", encoding="utf-8") as f:
@@ -128,11 +128,11 @@ def remove_model(model_path: str):
model_previews = utils.get_model_all_images(model_path)
for preview in model_previews:
os.remove(os.path.join(model_dirname, preview))
os.remove(utils.join_path(model_dirname, preview))
model_descriptions = utils.get_model_all_descriptions(model_path)
for description in model_descriptions:
os.remove(os.path.join(model_dirname, description))
os.remove(utils.join_path(model_dirname, description))
async def create_model_download_task(post):

View File

@@ -15,9 +15,18 @@ from typing import Any
from . import config
def normalize_path(path: str):
normpath = os.path.normpath(path)
return normpath.replace(os.path.sep, "/")
def join_path(path: str, *paths: list[str]):
return normalize_path(os.path.join(path, *paths))
def get_current_version():
try:
pyproject_path = os.path.join(config.extension_uri, "pyproject.toml")
pyproject_path = join_path(config.extension_uri, "pyproject.toml")
config_parser = configparser.ConfigParser()
config_parser.read(pyproject_path)
version = config_parser.get("project", "version")
@@ -27,13 +36,13 @@ def get_current_version():
def download_web_distribution(version: str):
web_path = os.path.join(config.extension_uri, "web")
dev_web_file = os.path.join(web_path, "manager-dev.js")
web_path = join_path(config.extension_uri, "web")
dev_web_file = join_path(web_path, "manager-dev.js")
if os.path.exists(dev_web_file):
return
web_version = "0.0.0"
version_file = os.path.join(web_path, "version.yaml")
version_file = join_path(web_path, "version.yaml")
if os.path.exists(version_file):
with open(version_file, "r") as f:
version_content = yaml.safe_load(f)
@@ -49,7 +58,7 @@ def download_web_distribution(version: str):
response = requests.get(download_url, stream=True)
response.raise_for_status()
temp_file = os.path.join(config.extension_uri, "temp.tar.gz")
temp_file = join_path(config.extension_uri, "temp.tar.gz")
with open(temp_file, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
@@ -82,7 +91,8 @@ def resolve_model_base_paths():
continue
if folder == "custom_nodes":
continue
config.model_base_paths[folder] = folder_paths.get_folder_paths(folder)
folders = folder_paths.get_folder_paths(folder)
config.model_base_paths[folder] = [normalize_path(f) for f in folders]
def get_full_path(model_type: str, path_index: int, filename: str):
@@ -93,7 +103,8 @@ def get_full_path(model_type: str, path_index: int, filename: str):
if not path_index < len(folders):
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
base_path = folders[path_index]
return os.path.join(base_path, filename)
full_path = join_path(base_path, filename)
return full_path
def get_valid_full_path(model_type: str, path_index: int, filename: str):
@@ -104,7 +115,7 @@ def get_valid_full_path(model_type: str, path_index: int, filename: str):
if not path_index < len(folders):
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
base_path = folders[path_index]
full_path = os.path.join(base_path, filename)
full_path = join_path(base_path, filename)
if os.path.isfile(full_path):
return full_path
elif os.path.islink(full_path):
@@ -114,7 +125,7 @@ def get_valid_full_path(model_type: str, path_index: int, filename: str):
def get_download_path():
download_path = os.path.join(config.extension_uri, "downloads")
download_path = join_path(config.extension_uri, "downloads")
if not os.path.exists(download_path):
os.makedirs(download_path)
return download_path
@@ -124,12 +135,12 @@ def recursive_search_files(directory: str):
files, folder_all = folder_paths.recursive_search(
directory, excluded_dir_names=[".git"]
)
return files
return [normalize_path(f) for f in files]
def search_files(directory: str):
entries = os.listdir(directory)
files = [f for f in entries if os.path.isfile(os.path.join(directory, f))]
files = [f for f in entries if os.path.isfile(join_path(directory, f))]
return files
@@ -137,7 +148,6 @@ def file_list_to_name_dict(files: list[str]):
file_dict: dict[str, str] = {}
for file in files:
filename = os.path.splitext(file)[0]
filename = filename.replace(os.path.sep, "/")
file_dict[filename] = file
return file_dict
@@ -194,13 +204,13 @@ def save_model_preview_image(model_path: str, image_file: Any):
for image in old_preview_images:
if os.path.splitext(image)[1].endswith(".preview"):
a1111_civitai_helper_image = True
image_path = os.path.join(base_dirname, image)
image_path = join_path(base_dirname, image)
os.remove(image_path)
# save new preview image
basename = os.path.splitext(os.path.basename(model_path))[0]
extension = f".{content_type.split('/')[1]}"
new_preview_path = os.path.join(base_dirname, f"{basename}{extension}")
new_preview_path = join_path(base_dirname, f"{basename}{extension}")
with open(new_preview_path, "wb") as f:
f.write(image_file.file.read())
@@ -210,7 +220,7 @@ def save_model_preview_image(model_path: str, image_file: Any):
"""
Keep preview image of a1111_civitai_helper
"""
new_preview_path = os.path.join(base_dirname, f"{basename}.preview{extension}")
new_preview_path = join_path(base_dirname, f"{basename}.preview{extension}")
with open(new_preview_path, "wb") as f:
f.write(image_file.file.read())
@@ -244,13 +254,13 @@ def save_model_description(model_path: str, content: Any):
# remove old descriptions
old_descriptions = get_model_all_descriptions(model_path)
for desc in old_descriptions:
description_path = os.path.join(base_dirname, desc)
description_path = join_path(base_dirname, desc)
os.remove(description_path)
# save new description
basename = os.path.splitext(os.path.basename(model_path))[0]
extension = ".md"
new_desc_path = os.path.join(base_dirname, f"{basename}{extension}")
new_desc_path = join_path(base_dirname, f"{basename}{extension}")
with open(new_desc_path, "w", encoding="utf-8") as f:
f.write(content)
@@ -278,23 +288,21 @@ def rename_model(model_path: str, new_model_path: str):
# move preview
previews = get_model_all_images(model_path)
for preview in previews:
preview_path = os.path.join(model_dirname, preview)
preview_path = join_path(model_dirname, preview)
preview_name = os.path.splitext(preview)[0]
preview_ext = os.path.splitext(preview)[1]
new_preview_path = (
os.path.join(new_model_dirname, new_model_name + preview_ext)
join_path(new_model_dirname, new_model_name + preview_ext)
if preview_name == model_name
else os.path.join(
new_model_dirname, new_model_name + ".preview" + preview_ext
)
else join_path(new_model_dirname, new_model_name + ".preview" + preview_ext)
)
shutil.move(preview_path, new_preview_path)
# move description
description = get_model_description_name(model_path)
description_path = os.path.join(model_dirname, description)
description_path = join_path(model_dirname, description)
if os.path.isfile(description_path):
new_description_path = os.path.join(new_model_dirname, f"{new_model_name}.md")
new_description_path = join_path(new_model_dirname, f"{new_model_name}.md")
shutil.move(description_path, new_description_path)