1 Commits

Author SHA1 Message Date
Hayden
c2406a1fd1 fix: missing parameter (#93) 2025-01-13 15:58:11 +08:00
5 changed files with 19 additions and 21 deletions

View File

@@ -95,7 +95,7 @@ async def get_model_paths(request):
"""
Returns the base folders for models.
"""
model_base_paths = utils.resolve_model_base_paths()
model_base_paths = utils.resolve_model_base_paths(request)
return web.json_response({"success": True, "data": model_base_paths})
@@ -160,7 +160,7 @@ async def read_model_info(request):
filename = request.match_info.get("filename", None)
try:
model_path = utils.get_valid_full_path(model_type, index, filename)
model_path = utils.get_valid_full_path(model_type, index, filename, request)
result = services.get_model_info(model_path)
return web.json_response({"success": True, "data": result})
except Exception as e:
@@ -189,10 +189,10 @@ async def update_model(request):
model_data: dict = await request.json()
try:
model_path = utils.get_valid_full_path(model_type, index, filename)
model_path = utils.get_valid_full_path(model_type, index, filename, request)
if model_path is None:
raise RuntimeError(f"File {filename} not found")
services.update_model(model_path, model_data)
services.update_model(model_path, model_data, request)
return web.json_response({"success": True})
except Exception as e:
error_msg = f"Update model failed: {str(e)}"
@@ -210,7 +210,7 @@ async def delete_model(request):
filename = request.match_info.get("filename", None)
try:
model_path = utils.get_valid_full_path(model_type, index, filename)
model_path = utils.get_valid_full_path(model_type, index, filename, request)
if model_path is None:
raise RuntimeError(f"File {filename} not found")
services.remove_model(model_path)

View File

@@ -167,7 +167,7 @@ async def create_model_download_task(task_data: dict, request):
path_index = int(task_data.get("pathIndex", None))
fullname = task_data.get("fullname", None)
model_path = utils.get_full_path(model_type, path_index, fullname)
model_path = utils.get_full_path(model_type, path_index, fullname, request)
# Check if the model path is valid
if os.path.exists(model_path):
raise RuntimeError(f"File already exists: {model_path}")
@@ -261,6 +261,7 @@ async def download_model(task_id: str, request):
progress_interval = 1.0
await download_model_file(
request,
task_id=task_id,
headers=headers,
progress_callback=report_progress,
@@ -288,6 +289,7 @@ async def download_model(task_id: str, request):
async def download_model_file(
request,
task_id: str,
headers: dict,
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
@@ -308,7 +310,7 @@ async def download_model_file(
with open(description_file, "w", encoding="utf-8", newline="") as f:
f.write(description)
model_path = utils.get_full_path(model_type, path_index, fullname)
model_path = utils.get_full_path(model_type, path_index, fullname, request)
utils.rename_model(download_tmp_file, model_path)
@@ -361,9 +363,7 @@ async def download_model_file(
)
if response.status_code not in (200, 206):
raise RuntimeError(
f"Failed to download {task_content.fullname}, status code: {response.status_code}"
)
raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}")
# Some models require logging in before they can be downloaded.
# If no token is carried, it will be redirected to the login page.
@@ -376,9 +376,7 @@ async def download_model_file(
# If it cannot be downloaded, a redirect will definitely occur.
# Maybe consider getting the redirect url from response.history to make a judgment.
# Here we also need to consider how different websites are processed.
raise RuntimeError(
f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
)
raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.")
# When parsing model information from HuggingFace API,
# the file size was not found and needs to be obtained from the response header.

View File

@@ -69,7 +69,7 @@ def get_model_info(model_path: str):
}
def update_model(model_path: str, model_data: dict):
def update_model(model_path: str, model_data: dict, request):
if "previewFile" in model_data:
previewFile = model_data["previewFile"]
@@ -87,7 +87,7 @@ def update_model(model_path: str, model_data: dict):
raise RuntimeError("Invalid type or pathIndex or fullname")
# get new path
new_model_path = utils.get_full_path(model_type, path_index, fullname)
new_model_path = utils.get_full_path(model_type, path_index, fullname, request)
utils.rename_model(model_path, new_model_path)
@@ -136,7 +136,7 @@ def fetch_model_info(model_page: str):
async def download_model_info(scan_mode: str, request):
utils.print_info(f"Download model info for {scan_mode}")
model_base_paths = utils.resolve_model_base_paths()
model_base_paths = utils.resolve_model_base_paths(request)
for model_type in model_base_paths:
folders, extensions = folder_paths.folder_names_and_paths[model_type]

View File

@@ -131,11 +131,11 @@ def resolve_model_base_paths(request):
return model_base_paths
def get_full_path(model_type: str, path_index: int, filename: str):
def get_full_path(model_type: str, path_index: int, filename: str, request):
"""
Get the absolute path in the model type through string concatenation.
"""
folders = resolve_model_base_paths().get(model_type, [])
folders = resolve_model_base_paths(request).get(model_type, [])
if not path_index < len(folders):
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
base_path = folders[path_index]
@@ -143,11 +143,11 @@ def get_full_path(model_type: str, path_index: int, filename: str):
return full_path
def get_valid_full_path(model_type: str, path_index: int, filename: str):
def get_valid_full_path(model_type: str, path_index: int, filename: str, request):
"""
Like get_full_path but it will check whether the file is valid.
"""
folders = resolve_model_base_paths().get(model_type, [])
folders = resolve_model_base_paths(request).get(model_type, [])
if not path_index < len(folders):
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
base_path = folders[path_index]

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-model-manager"
description = "Manage models: browsing, download and delete."
version = "2.2.0"
version = "2.2.1"
license = { file = "LICENSE" }
dependencies = ["markdownify"]