diff --git a/__init__.py b/__init__.py index 2935af2..9d504ea 100644 --- a/__init__.py +++ b/__init__.py @@ -95,7 +95,7 @@ async def get_model_paths(request): """ Returns the base folders for models. """ - model_base_paths = utils.resolve_model_base_paths() + model_base_paths = utils.resolve_model_base_paths(request) return web.json_response({"success": True, "data": model_base_paths}) @@ -160,7 +160,7 @@ async def read_model_info(request): filename = request.match_info.get("filename", None) try: - model_path = utils.get_valid_full_path(model_type, index, filename) + model_path = utils.get_valid_full_path(model_type, index, filename, request) result = services.get_model_info(model_path) return web.json_response({"success": True, "data": result}) except Exception as e: @@ -189,10 +189,10 @@ async def update_model(request): model_data: dict = await request.json() try: - model_path = utils.get_valid_full_path(model_type, index, filename) + model_path = utils.get_valid_full_path(model_type, index, filename, request) if model_path is None: raise RuntimeError(f"File {filename} not found") - services.update_model(model_path, model_data) + services.update_model(model_path, model_data, request) return web.json_response({"success": True}) except Exception as e: error_msg = f"Update model failed: {str(e)}" @@ -210,7 +210,7 @@ async def delete_model(request): filename = request.match_info.get("filename", None) try: - model_path = utils.get_valid_full_path(model_type, index, filename) + model_path = utils.get_valid_full_path(model_type, index, filename, request) if model_path is None: raise RuntimeError(f"File {filename} not found") services.remove_model(model_path) diff --git a/py/download.py b/py/download.py index 84f61d2..44ecf6d 100644 --- a/py/download.py +++ b/py/download.py @@ -167,7 +167,7 @@ async def create_model_download_task(task_data: dict, request): path_index = int(task_data.get("pathIndex", None)) fullname = task_data.get("fullname", None) - model_path = utils.get_full_path(model_type, path_index, fullname) + model_path = utils.get_full_path(model_type, path_index, fullname, request) # Check if the model path is valid if os.path.exists(model_path): raise RuntimeError(f"File already exists: {model_path}") @@ -261,6 +261,7 @@ async def download_model(task_id: str, request): progress_interval = 1.0 await download_model_file( + request, task_id=task_id, headers=headers, progress_callback=report_progress, @@ -288,6 +289,7 @@ async def download_model(task_id: str, request): async def download_model_file( + request, task_id: str, headers: dict, progress_callback: Callable[[TaskStatus], Awaitable[Any]], @@ -308,7 +310,7 @@ async def download_model_file( with open(description_file, "w", encoding="utf-8", newline="") as f: f.write(description) - model_path = utils.get_full_path(model_type, path_index, fullname) + model_path = utils.get_full_path(model_type, path_index, fullname, request) utils.rename_model(download_tmp_file, model_path) @@ -361,9 +363,7 @@ async def download_model_file( ) if response.status_code not in (200, 206): - raise RuntimeError( - f"Failed to download {task_content.fullname}, status code: {response.status_code}" - ) + raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}") # Some models require logging in before they can be downloaded. # If no token is carried, it will be redirected to the login page. @@ -376,9 +376,7 @@ async def download_model_file( # If it cannot be downloaded, a redirect will definitely occur. # Maybe consider getting the redirect url from response.history to make a judgment. # Here we also need to consider how different websites are processed. - raise RuntimeError( - f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first." - ) + raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.") # When parsing model information from HuggingFace API, # the file size was not found and needs to be obtained from the response header. diff --git a/py/services.py b/py/services.py index b623bc4..129d3db 100644 --- a/py/services.py +++ b/py/services.py @@ -69,7 +69,7 @@ def get_model_info(model_path: str): } -def update_model(model_path: str, model_data: dict): +def update_model(model_path: str, model_data: dict, request): if "previewFile" in model_data: previewFile = model_data["previewFile"] @@ -87,7 +87,7 @@ def update_model(model_path: str, model_data: dict): raise RuntimeError("Invalid type or pathIndex or fullname") # get new path - new_model_path = utils.get_full_path(model_type, path_index, fullname) + new_model_path = utils.get_full_path(model_type, path_index, fullname, request) utils.rename_model(model_path, new_model_path) @@ -136,7 +136,7 @@ def fetch_model_info(model_page: str): async def download_model_info(scan_mode: str, request): utils.print_info(f"Download model info for {scan_mode}") - model_base_paths = utils.resolve_model_base_paths() + model_base_paths = utils.resolve_model_base_paths(request) for model_type in model_base_paths: folders, extensions = folder_paths.folder_names_and_paths[model_type] diff --git a/py/utils.py b/py/utils.py index 680a588..7878029 100644 --- a/py/utils.py +++ b/py/utils.py @@ -131,11 +131,11 @@ def resolve_model_base_paths(request): return model_base_paths -def get_full_path(model_type: str, path_index: int, filename: str): +def get_full_path(model_type: str, path_index: int, filename: str, request): """ Get the absolute path in the model type through string concatenation. """ - folders = resolve_model_base_paths().get(model_type, []) + folders = resolve_model_base_paths(request).get(model_type, []) if not path_index < len(folders): raise RuntimeError(f"PathIndex {path_index} is not in {model_type}") base_path = folders[path_index] @@ -143,11 +143,11 @@ def get_full_path(model_type: str, path_index: int, filename: str): return full_path -def get_valid_full_path(model_type: str, path_index: int, filename: str): +def get_valid_full_path(model_type: str, path_index: int, filename: str, request): """ Like get_full_path but it will check whether the file is valid. """ - folders = resolve_model_base_paths().get(model_type, []) + folders = resolve_model_base_paths(request).get(model_type, []) if not path_index < len(folders): raise RuntimeError(f"PathIndex {path_index} is not in {model_type}") base_path = folders[path_index] diff --git a/pyproject.toml b/pyproject.toml index ac91c26..80c84d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-model-manager" description = "Manage models: browsing, download and delete." -version = "2.2.0" +version = "2.2.1" license = { file = "LICENSE" } dependencies = ["markdownify"]