fix: potential bug after adding excluded directories (#94)
* Revert "fix: missing parameter (#93)" This reverts commitc2406a1fd1. * Revert "feat: add exclude scan model types (#92)" This reverts commit40a1a7f43a. * feat: add exclude scan model types * fix: potential bug after adding excluded directories
This commit is contained in:
@@ -12,8 +12,7 @@ setting_key = {
|
||||
"max_task_count": "ModelManager.Download.MaxTaskCount",
|
||||
},
|
||||
"scan": {
|
||||
"include_hidden_files": "ModelManager.Scan.IncludeHiddenFiles",
|
||||
"exclude_scan_types": "ModelManager.Scan.excludeScanTypes",
|
||||
"include_hidden_files": "ModelManager.Scan.IncludeHiddenFiles"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ async def create_model_download_task(task_data: dict, request):
|
||||
path_index = int(task_data.get("pathIndex", None))
|
||||
fullname = task_data.get("fullname", None)
|
||||
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname, request)
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
# Check if the model path is valid
|
||||
if os.path.exists(model_path):
|
||||
raise RuntimeError(f"File already exists: {model_path}")
|
||||
@@ -261,7 +261,6 @@ async def download_model(task_id: str, request):
|
||||
|
||||
progress_interval = 1.0
|
||||
await download_model_file(
|
||||
request,
|
||||
task_id=task_id,
|
||||
headers=headers,
|
||||
progress_callback=report_progress,
|
||||
@@ -289,7 +288,6 @@ async def download_model(task_id: str, request):
|
||||
|
||||
|
||||
async def download_model_file(
|
||||
request,
|
||||
task_id: str,
|
||||
headers: dict,
|
||||
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
||||
@@ -310,7 +308,7 @@ async def download_model_file(
|
||||
with open(description_file, "w", encoding="utf-8", newline="") as f:
|
||||
f.write(description)
|
||||
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname, request)
|
||||
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
|
||||
utils.rename_model(download_tmp_file, model_path)
|
||||
|
||||
@@ -363,7 +361,9 @@ async def download_model_file(
|
||||
)
|
||||
|
||||
if response.status_code not in (200, 206):
|
||||
raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}")
|
||||
raise RuntimeError(
|
||||
f"Failed to download {task_content.fullname}, status code: {response.status_code}"
|
||||
)
|
||||
|
||||
# Some models require logging in before they can be downloaded.
|
||||
# If no token is carried, it will be redirected to the login page.
|
||||
@@ -376,7 +376,9 @@ async def download_model_file(
|
||||
# If it cannot be downloaded, a redirect will definitely occur.
|
||||
# Maybe consider getting the redirect url from response.history to make a judgment.
|
||||
# Here we also need to consider how different websites are processed.
|
||||
raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.")
|
||||
raise RuntimeError(
|
||||
f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
|
||||
)
|
||||
|
||||
# When parsing model information from HuggingFace API,
|
||||
# the file size was not found and needs to be obtained from the response header.
|
||||
|
||||
@@ -69,7 +69,7 @@ def get_model_info(model_path: str):
|
||||
}
|
||||
|
||||
|
||||
def update_model(model_path: str, model_data: dict, request):
|
||||
def update_model(model_path: str, model_data: dict):
|
||||
|
||||
if "previewFile" in model_data:
|
||||
previewFile = model_data["previewFile"]
|
||||
@@ -87,7 +87,7 @@ def update_model(model_path: str, model_data: dict, request):
|
||||
raise RuntimeError("Invalid type or pathIndex or fullname")
|
||||
|
||||
# get new path
|
||||
new_model_path = utils.get_full_path(model_type, path_index, fullname, request)
|
||||
new_model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||
|
||||
utils.rename_model(model_path, new_model_path)
|
||||
|
||||
@@ -136,7 +136,7 @@ def fetch_model_info(model_page: str):
|
||||
|
||||
async def download_model_info(scan_mode: str, request):
|
||||
utils.print_info(f"Download model info for {scan_mode}")
|
||||
model_base_paths = utils.resolve_model_base_paths(request)
|
||||
model_base_paths = utils.resolve_model_base_paths()
|
||||
for model_type in model_base_paths:
|
||||
|
||||
folders, extensions = folder_paths.folder_names_and_paths[model_type]
|
||||
|
||||
13
py/utils.py
13
py/utils.py
@@ -116,13 +116,10 @@ def download_web_distribution(version: str):
|
||||
print_error(f"An unexpected error occurred: {e}")
|
||||
|
||||
|
||||
def resolve_model_base_paths(request):
|
||||
def resolve_model_base_paths():
|
||||
folders = list(folder_paths.folder_names_and_paths.keys())
|
||||
model_base_paths = {}
|
||||
folder_black_list = ["configs", "custom_nodes"]
|
||||
custom_folders = get_setting_value(request, "scan.exclude_scan_types", "")
|
||||
custom_black_list = [f.strip() for f in custom_folders.split(",") if f.strip()]
|
||||
folder_black_list.extend(custom_black_list)
|
||||
for folder in folders:
|
||||
if folder in folder_black_list:
|
||||
continue
|
||||
@@ -131,11 +128,11 @@ def resolve_model_base_paths(request):
|
||||
return model_base_paths
|
||||
|
||||
|
||||
def get_full_path(model_type: str, path_index: int, filename: str, request):
|
||||
def get_full_path(model_type: str, path_index: int, filename: str):
|
||||
"""
|
||||
Get the absolute path in the model type through string concatenation.
|
||||
"""
|
||||
folders = resolve_model_base_paths(request).get(model_type, [])
|
||||
folders = resolve_model_base_paths().get(model_type, [])
|
||||
if not path_index < len(folders):
|
||||
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
||||
base_path = folders[path_index]
|
||||
@@ -143,11 +140,11 @@ def get_full_path(model_type: str, path_index: int, filename: str, request):
|
||||
return full_path
|
||||
|
||||
|
||||
def get_valid_full_path(model_type: str, path_index: int, filename: str, request):
|
||||
def get_valid_full_path(model_type: str, path_index: int, filename: str):
|
||||
"""
|
||||
Like get_full_path but it will check whether the file is valid.
|
||||
"""
|
||||
folders = resolve_model_base_paths(request).get(model_type, [])
|
||||
folders = resolve_model_base_paths().get(model_type, [])
|
||||
if not path_index < len(folders):
|
||||
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
||||
base_path = folders[path_index]
|
||||
|
||||
Reference in New Issue
Block a user