Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be383ac6e1 |
12
__init__.py
12
__init__.py
@@ -95,7 +95,7 @@ async def get_model_paths(request):
|
|||||||
"""
|
"""
|
||||||
Returns the base folders for models.
|
Returns the base folders for models.
|
||||||
"""
|
"""
|
||||||
model_base_paths = utils.resolve_model_base_paths(request)
|
model_base_paths = utils.resolve_model_base_paths()
|
||||||
return web.json_response({"success": True, "data": model_base_paths})
|
return web.json_response({"success": True, "data": model_base_paths})
|
||||||
|
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ async def list_model_types(request):
|
|||||||
Scan all models and read their information.
|
Scan all models and read their information.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = utils.resolve_model_base_paths(request)
|
result = utils.resolve_model_base_paths()
|
||||||
return web.json_response({"success": True, "data": result})
|
return web.json_response({"success": True, "data": result})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Read models failed: {str(e)}"
|
error_msg = f"Read models failed: {str(e)}"
|
||||||
@@ -160,7 +160,7 @@ async def read_model_info(request):
|
|||||||
filename = request.match_info.get("filename", None)
|
filename = request.match_info.get("filename", None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_path = utils.get_valid_full_path(model_type, index, filename, request)
|
model_path = utils.get_valid_full_path(model_type, index, filename)
|
||||||
result = services.get_model_info(model_path)
|
result = services.get_model_info(model_path)
|
||||||
return web.json_response({"success": True, "data": result})
|
return web.json_response({"success": True, "data": result})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -189,10 +189,10 @@ async def update_model(request):
|
|||||||
model_data: dict = await request.json()
|
model_data: dict = await request.json()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_path = utils.get_valid_full_path(model_type, index, filename, request)
|
model_path = utils.get_valid_full_path(model_type, index, filename)
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
raise RuntimeError(f"File {filename} not found")
|
raise RuntimeError(f"File {filename} not found")
|
||||||
services.update_model(model_path, model_data, request)
|
services.update_model(model_path, model_data)
|
||||||
return web.json_response({"success": True})
|
return web.json_response({"success": True})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Update model failed: {str(e)}"
|
error_msg = f"Update model failed: {str(e)}"
|
||||||
@@ -210,7 +210,7 @@ async def delete_model(request):
|
|||||||
filename = request.match_info.get("filename", None)
|
filename = request.match_info.get("filename", None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_path = utils.get_valid_full_path(model_type, index, filename, request)
|
model_path = utils.get_valid_full_path(model_type, index, filename)
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
raise RuntimeError(f"File {filename} not found")
|
raise RuntimeError(f"File {filename} not found")
|
||||||
services.remove_model(model_path)
|
services.remove_model(model_path)
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ setting_key = {
|
|||||||
"max_task_count": "ModelManager.Download.MaxTaskCount",
|
"max_task_count": "ModelManager.Download.MaxTaskCount",
|
||||||
},
|
},
|
||||||
"scan": {
|
"scan": {
|
||||||
"include_hidden_files": "ModelManager.Scan.IncludeHiddenFiles",
|
"include_hidden_files": "ModelManager.Scan.IncludeHiddenFiles"
|
||||||
"exclude_scan_types": "ModelManager.Scan.excludeScanTypes",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ async def create_model_download_task(task_data: dict, request):
|
|||||||
path_index = int(task_data.get("pathIndex", None))
|
path_index = int(task_data.get("pathIndex", None))
|
||||||
fullname = task_data.get("fullname", None)
|
fullname = task_data.get("fullname", None)
|
||||||
|
|
||||||
model_path = utils.get_full_path(model_type, path_index, fullname, request)
|
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||||
# Check if the model path is valid
|
# Check if the model path is valid
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
raise RuntimeError(f"File already exists: {model_path}")
|
raise RuntimeError(f"File already exists: {model_path}")
|
||||||
@@ -261,7 +261,6 @@ async def download_model(task_id: str, request):
|
|||||||
|
|
||||||
progress_interval = 1.0
|
progress_interval = 1.0
|
||||||
await download_model_file(
|
await download_model_file(
|
||||||
request,
|
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
progress_callback=report_progress,
|
progress_callback=report_progress,
|
||||||
@@ -289,7 +288,6 @@ async def download_model(task_id: str, request):
|
|||||||
|
|
||||||
|
|
||||||
async def download_model_file(
|
async def download_model_file(
|
||||||
request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
progress_callback: Callable[[TaskStatus], Awaitable[Any]],
|
||||||
@@ -310,7 +308,7 @@ async def download_model_file(
|
|||||||
with open(description_file, "w", encoding="utf-8", newline="") as f:
|
with open(description_file, "w", encoding="utf-8", newline="") as f:
|
||||||
f.write(description)
|
f.write(description)
|
||||||
|
|
||||||
model_path = utils.get_full_path(model_type, path_index, fullname, request)
|
model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||||
|
|
||||||
utils.rename_model(download_tmp_file, model_path)
|
utils.rename_model(download_tmp_file, model_path)
|
||||||
|
|
||||||
@@ -363,7 +361,9 @@ async def download_model_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code not in (200, 206):
|
if response.status_code not in (200, 206):
|
||||||
raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}")
|
raise RuntimeError(
|
||||||
|
f"Failed to download {task_content.fullname}, status code: {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
# Some models require logging in before they can be downloaded.
|
# Some models require logging in before they can be downloaded.
|
||||||
# If no token is carried, it will be redirected to the login page.
|
# If no token is carried, it will be redirected to the login page.
|
||||||
@@ -376,7 +376,9 @@ async def download_model_file(
|
|||||||
# If it cannot be downloaded, a redirect will definitely occur.
|
# If it cannot be downloaded, a redirect will definitely occur.
|
||||||
# Maybe consider getting the redirect url from response.history to make a judgment.
|
# Maybe consider getting the redirect url from response.history to make a judgment.
|
||||||
# Here we also need to consider how different websites are processed.
|
# Here we also need to consider how different websites are processed.
|
||||||
raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.")
|
raise RuntimeError(
|
||||||
|
f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first."
|
||||||
|
)
|
||||||
|
|
||||||
# When parsing model information from HuggingFace API,
|
# When parsing model information from HuggingFace API,
|
||||||
# the file size was not found and needs to be obtained from the response header.
|
# the file size was not found and needs to be obtained from the response header.
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ def get_model_info(model_path: str):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def update_model(model_path: str, model_data: dict, request):
|
def update_model(model_path: str, model_data: dict):
|
||||||
|
|
||||||
if "previewFile" in model_data:
|
if "previewFile" in model_data:
|
||||||
previewFile = model_data["previewFile"]
|
previewFile = model_data["previewFile"]
|
||||||
@@ -87,7 +87,7 @@ def update_model(model_path: str, model_data: dict, request):
|
|||||||
raise RuntimeError("Invalid type or pathIndex or fullname")
|
raise RuntimeError("Invalid type or pathIndex or fullname")
|
||||||
|
|
||||||
# get new path
|
# get new path
|
||||||
new_model_path = utils.get_full_path(model_type, path_index, fullname, request)
|
new_model_path = utils.get_full_path(model_type, path_index, fullname)
|
||||||
|
|
||||||
utils.rename_model(model_path, new_model_path)
|
utils.rename_model(model_path, new_model_path)
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ def fetch_model_info(model_page: str):
|
|||||||
|
|
||||||
async def download_model_info(scan_mode: str, request):
|
async def download_model_info(scan_mode: str, request):
|
||||||
utils.print_info(f"Download model info for {scan_mode}")
|
utils.print_info(f"Download model info for {scan_mode}")
|
||||||
model_base_paths = utils.resolve_model_base_paths(request)
|
model_base_paths = utils.resolve_model_base_paths()
|
||||||
for model_type in model_base_paths:
|
for model_type in model_base_paths:
|
||||||
|
|
||||||
folders, extensions = folder_paths.folder_names_and_paths[model_type]
|
folders, extensions = folder_paths.folder_names_and_paths[model_type]
|
||||||
|
|||||||
13
py/utils.py
13
py/utils.py
@@ -116,13 +116,10 @@ def download_web_distribution(version: str):
|
|||||||
print_error(f"An unexpected error occurred: {e}")
|
print_error(f"An unexpected error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def resolve_model_base_paths(request):
|
def resolve_model_base_paths():
|
||||||
folders = list(folder_paths.folder_names_and_paths.keys())
|
folders = list(folder_paths.folder_names_and_paths.keys())
|
||||||
model_base_paths = {}
|
model_base_paths = {}
|
||||||
folder_black_list = ["configs", "custom_nodes"]
|
folder_black_list = ["configs", "custom_nodes"]
|
||||||
custom_folders = get_setting_value(request, "scan.exclude_scan_types", "")
|
|
||||||
custom_black_list = [f.strip() for f in custom_folders.split(",") if f.strip()]
|
|
||||||
folder_black_list.extend(custom_black_list)
|
|
||||||
for folder in folders:
|
for folder in folders:
|
||||||
if folder in folder_black_list:
|
if folder in folder_black_list:
|
||||||
continue
|
continue
|
||||||
@@ -131,11 +128,11 @@ def resolve_model_base_paths(request):
|
|||||||
return model_base_paths
|
return model_base_paths
|
||||||
|
|
||||||
|
|
||||||
def get_full_path(model_type: str, path_index: int, filename: str, request):
|
def get_full_path(model_type: str, path_index: int, filename: str):
|
||||||
"""
|
"""
|
||||||
Get the absolute path in the model type through string concatenation.
|
Get the absolute path in the model type through string concatenation.
|
||||||
"""
|
"""
|
||||||
folders = resolve_model_base_paths(request).get(model_type, [])
|
folders = resolve_model_base_paths().get(model_type, [])
|
||||||
if not path_index < len(folders):
|
if not path_index < len(folders):
|
||||||
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
||||||
base_path = folders[path_index]
|
base_path = folders[path_index]
|
||||||
@@ -143,11 +140,11 @@ def get_full_path(model_type: str, path_index: int, filename: str, request):
|
|||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def get_valid_full_path(model_type: str, path_index: int, filename: str, request):
|
def get_valid_full_path(model_type: str, path_index: int, filename: str):
|
||||||
"""
|
"""
|
||||||
Like get_full_path but it will check whether the file is valid.
|
Like get_full_path but it will check whether the file is valid.
|
||||||
"""
|
"""
|
||||||
folders = resolve_model_base_paths(request).get(model_type, [])
|
folders = resolve_model_base_paths().get(model_type, [])
|
||||||
if not path_index < len(folders):
|
if not path_index < len(folders):
|
||||||
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
|
||||||
base_path = folders[path_index]
|
base_path = folders[path_index]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-model-manager"
|
name = "comfyui-model-manager"
|
||||||
description = "Manage models: browsing, download and delete."
|
description = "Manage models: browsing, download and delete."
|
||||||
version = "2.2.1"
|
version = "2.2.2"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
dependencies = ["markdownify"]
|
dependencies = ["markdownify"]
|
||||||
|
|
||||||
|
|||||||
@@ -68,11 +68,12 @@ import ModelCard from 'components/ModelCard.vue'
|
|||||||
import ResponseInput from 'components/ResponseInput.vue'
|
import ResponseInput from 'components/ResponseInput.vue'
|
||||||
import ResponseScroll from 'components/ResponseScroll.vue'
|
import ResponseScroll from 'components/ResponseScroll.vue'
|
||||||
import ResponseSelect from 'components/ResponseSelect.vue'
|
import ResponseSelect from 'components/ResponseSelect.vue'
|
||||||
import { useConfig } from 'hooks/config'
|
import { configSetting, useConfig } from 'hooks/config'
|
||||||
import { useContainerQueries } from 'hooks/container'
|
import { useContainerQueries } from 'hooks/container'
|
||||||
import { useModels } from 'hooks/model'
|
import { useModels } from 'hooks/model'
|
||||||
import { defineResizeCallback } from 'hooks/resize'
|
import { defineResizeCallback } from 'hooks/resize'
|
||||||
import { chunk } from 'lodash'
|
import { chunk } from 'lodash'
|
||||||
|
import { app } from 'scripts/comfyAPI'
|
||||||
import { Model } from 'types/typings'
|
import { Model } from 'types/typings'
|
||||||
import { genModelKey } from 'utils/model'
|
import { genModelKey } from 'utils/model'
|
||||||
import { computed, ref, watch } from 'vue'
|
import { computed, ref, watch } from 'vue'
|
||||||
@@ -89,7 +90,20 @@ const searchContent = ref<string>()
|
|||||||
|
|
||||||
const currentType = ref('all')
|
const currentType = ref('all')
|
||||||
const typeOptions = computed(() => {
|
const typeOptions = computed(() => {
|
||||||
return ['all', ...Object.keys(folders.value)].map((type) => {
|
const excludeScanTypes = app.ui?.settings.getSettingValue<string>(
|
||||||
|
configSetting.excludeScanTypes,
|
||||||
|
)
|
||||||
|
const customBlackList =
|
||||||
|
excludeScanTypes
|
||||||
|
?.split(',')
|
||||||
|
.map((type) => type.trim())
|
||||||
|
.filter(Boolean) ?? []
|
||||||
|
return [
|
||||||
|
'all',
|
||||||
|
...Object.keys(folders.value).filter(
|
||||||
|
(folder) => !customBlackList.includes(folder),
|
||||||
|
),
|
||||||
|
].map((type) => {
|
||||||
return {
|
return {
|
||||||
label: type,
|
label: type,
|
||||||
value: type,
|
value: type,
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ declare module 'hooks/store' {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const configSetting = {
|
||||||
|
excludeScanTypes: 'ModelManager.Scan.excludeScanTypes',
|
||||||
|
}
|
||||||
|
|
||||||
function useAddConfigSettings(store: import('hooks/store').StoreProvider) {
|
function useAddConfigSettings(store: import('hooks/store').StoreProvider) {
|
||||||
const { toast } = useToast()
|
const { toast } = useToast()
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
@@ -191,7 +195,7 @@ function useAddConfigSettings(store: import('hooks/store').StoreProvider) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
app.ui?.settings.addSetting({
|
app.ui?.settings.addSetting({
|
||||||
id: 'ModelManager.Scan.excludeScanTypes',
|
id: configSetting.excludeScanTypes,
|
||||||
category: [t('modelManager'), t('setting.scan'), 'ExcludeScanTypes'],
|
category: [t('modelManager'), t('setting.scan'), 'ExcludeScanTypes'],
|
||||||
name: t('setting.excludeScanTypes'),
|
name: t('setting.excludeScanTypes'),
|
||||||
defaultValue: undefined,
|
defaultValue: undefined,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import {
|
|||||||
unref,
|
unref,
|
||||||
} from 'vue'
|
} from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
|
import { configSetting } from './config'
|
||||||
|
|
||||||
type ModelFolder = Record<string, string[]>
|
type ModelFolder = Record<string, string[]>
|
||||||
|
|
||||||
@@ -56,8 +57,20 @@ export const useModels = defineStore('models', (store) => {
|
|||||||
const refreshAllModels = async (force = false) => {
|
const refreshAllModels = async (force = false) => {
|
||||||
const forceRefresh = force ? refreshFolders() : Promise.resolve()
|
const forceRefresh = force ? refreshFolders() : Promise.resolve()
|
||||||
models.value = {}
|
models.value = {}
|
||||||
|
const excludeScanTypes = app.ui?.settings.getSettingValue<string>(
|
||||||
|
configSetting.excludeScanTypes,
|
||||||
|
)
|
||||||
|
const customBlackList =
|
||||||
|
excludeScanTypes
|
||||||
|
?.split(',')
|
||||||
|
.map((type) => type.trim())
|
||||||
|
.filter(Boolean) ?? []
|
||||||
return forceRefresh.then(() =>
|
return forceRefresh.then(() =>
|
||||||
Promise.allSettled(Object.keys(folders.value).map(refreshModels)),
|
Promise.allSettled(
|
||||||
|
Object.keys(folders.value)
|
||||||
|
.filter((folder) => !customBlackList.includes(folder))
|
||||||
|
.map(refreshModels),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user