refactor: Migrate project functionality and optimize the code structure
33  py/config.py  Normal file
@@ -0,0 +1,33 @@
from typing import Optional

extension_uri: Optional[str] = None

model_base_paths: dict[str, list[str]] = {}


# Mapping from the dotted keys used inside this extension to the setting IDs
# stored by ComfyUI (resolved by utils.resolve_setting_key).
setting_key = {
    "api_key": {
        "civitai": "ModelManager.APIKey.Civitai",
        "huggingface": "ModelManager.APIKey.HuggingFace",
    },
    "download": {
        "max_task_count": "ModelManager.Download.MaxTaskCount",
    },
}

user_agent = "Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148"


from server import PromptServer

serverInstance = PromptServer.instance
routes = serverInstance.routes


class FakeRequest:
    """A stand-in request object for server APIs that only read request.headers."""

    def __init__(self):
        self.headers = {}


class CustomException(BaseException):
    def __init__(self, type: str, message: Optional[str] = None) -> None:
        self.type = type
        self.message = message
        super().__init__(message)
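Note: a minimal standalone sketch of how the dotted keys above resolve to
ComfyUI setting IDs (it mirrors utils.resolve_setting_key later in this commit):

    setting_key = {
        "api_key": {"civitai": "ModelManager.APIKey.Civitai"},
        "download": {"max_task_count": "ModelManager.Download.MaxTaskCount"},
    }

    def resolve(key: str) -> str:
        node = setting_key
        for part in key.split("."):
            node = node[part]  # walk one level per dotted segment
        return node

    resolve("download.max_task_count")  # -> "ModelManager.Download.MaxTaskCount"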
362  py/download.py  Normal file
@@ -0,0 +1,362 @@
import os
import uuid
import time
import logging
import requests
import folder_paths
import traceback
from typing import Callable, Awaitable, Any, Literal, Union, Optional
from dataclasses import dataclass
from . import config
from . import utils
from . import socket
from . import thread


@dataclass
class TaskStatus:
    taskId: str
    type: str
    fullname: str
    preview: str
    status: Literal["pause", "waiting", "doing"] = "pause"
    platform: Union[str, None] = None
    downloadedSize: float = 0
    totalSize: float = 0
    progress: float = 0
    bps: float = 0
    error: Optional[str] = None


@dataclass
class TaskContent:
    type: str
    pathIndex: int
    fullname: str
    description: str
    downloadPlatform: str
    downloadUrl: str
    sizeBytes: float
    hashes: Optional[dict[str, str]] = None


download_model_task_status: dict[str, TaskStatus] = {}
download_thread_pool = thread.DownloadThreadPool()
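Note: these dataclasses travel over the websocket as plain dicts; a standalone
sketch of that conversion (utils.unpack_dataclass below does the same via
dataclasses.asdict):

    from dataclasses import asdict

    status = TaskStatus(taskId="abc123", type="checkpoints",
                        fullname="model.safetensors", preview="no-preview.png")
    asdict(status)  # {"taskId": "abc123", "status": "pause", "progress": 0, ...}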
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
    download_path = utils.get_download_path()
    task_file_path = os.path.join(download_path, f"{task_id}.task")
    utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))


def get_task_content(task_id: str):
    download_path = utils.get_download_path()
    task_file = os.path.join(download_path, f"{task_id}.task")
    if not os.path.isfile(task_file):
        raise RuntimeError(f"Task {task_id} not found")
    task_content = utils.load_dict_pickle_file(task_file)
    task_content["pathIndex"] = int(task_content.get("pathIndex", 0))
    task_content["sizeBytes"] = float(task_content.get("sizeBytes", 0))
    return TaskContent(**task_content)
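Note: a hedged round-trip sketch of the task persistence above (the task id
and field values are hypothetical):

    content = TaskContent(type="loras", pathIndex=0, fullname="x.safetensors",
                          description="", downloadPlatform="civitai",
                          downloadUrl="https://example.com/x", sizeBytes=0)
    set_task_content("deadbeef", content)    # writes <downloads>/deadbeef.task
    restored = get_task_content("deadbeef")  # -> TaskContent(...), fields coerced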
def get_task_status(task_id: str):
    task_status = download_model_task_status.get(task_id, None)

    if task_status is None:
        download_path = utils.get_download_path()
        task_content = get_task_content(task_id)
        download_file = os.path.join(download_path, f"{task_id}.download")
        download_size = 0
        if os.path.exists(download_file):
            download_size = os.path.getsize(download_file)

        total_size = task_content.sizeBytes
        task_status = TaskStatus(
            taskId=task_id,
            type=task_content.type,
            fullname=task_content.fullname,
            preview=utils.get_model_preview_name(download_file),
            platform=task_content.downloadPlatform,
            downloadedSize=download_size,
            totalSize=task_content.sizeBytes,
            progress=download_size / total_size * 100 if total_size > 0 else 0,
        )

        download_model_task_status[task_id] = task_status

    return task_status
def delete_task_status(task_id: str):
    download_model_task_status.pop(task_id, None)


async def scan_model_download_task_list(sid: str):
    """
    Scan the download directory and send the task list to the client.
    """
    try:
        download_dir = utils.get_download_path()
        task_files = utils.search_files(download_dir)
        task_files = folder_paths.filter_files_extensions(task_files, [".task"])
        task_files = sorted(
            task_files,
            key=lambda x: os.stat(os.path.join(download_dir, x)).st_ctime,
            reverse=True,
        )
        task_list: list[TaskStatus] = []
        for task_file in task_files:
            task_id = task_file.replace(".task", "")
            task_status = get_task_status(task_id)
            task_list.append(task_status)

        await socket.send_json("downloadTaskList", task_list, sid)
    except Exception as e:
        error_msg = f"Refresh task list failed: {e}"
        await socket.send_json("error", error_msg, sid)
        logging.error(error_msg)
async def create_model_download_task(post: dict):
    """
    Creates a download task for the given post.
    """
    model_type = post.get("type", None)
    # Default to the first base path when pathIndex is missing
    path_index = int(post.get("pathIndex", 0))
    fullname = post.get("fullname", None)

    model_path = utils.get_full_path(model_type, path_index, fullname)
    # Check if the model path is valid
    if os.path.exists(model_path):
        raise RuntimeError(f"File already exists: {model_path}")

    download_path = utils.get_download_path()

    task_id = uuid.uuid4().hex
    task_path = os.path.join(download_path, f"{task_id}.task")
    if os.path.exists(task_path):
        raise RuntimeError(f"Task {task_id} already exists")

    try:
        previewFile = post.pop("previewFile", None)
        utils.save_model_preview_image(task_path, previewFile)
        set_task_content(task_id, post)
        task_status = TaskStatus(
            taskId=task_id,
            type=model_type,
            fullname=fullname,
            preview=utils.get_model_preview_name(task_path),
            platform=post.get("downloadPlatform", None),
            totalSize=float(post.get("sizeBytes", 0)),
        )
        download_model_task_status[task_id] = task_status
        await socket.send_json("createDownloadTask", task_status)
    except Exception as e:
        await delete_model_download_task(task_id)
        raise RuntimeError(str(e)) from e

    await download_model(task_id)
    return task_id
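Note: a hypothetical `post` payload for the function above ("previewFile" is
popped before the task is persisted; the URL and values are made up):

    task_id = await create_model_download_task({
        "type": "checkpoints",
        "pathIndex": 0,
        "fullname": "example_model.safetensors",
        "description": "",
        "downloadPlatform": "huggingface",
        "downloadUrl": "https://huggingface.co/example/model/resolve/main/example_model.safetensors",
        "sizeBytes": 0,  # 0 lets the size be read from the response headers
    })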
async def pause_model_download_task(task_id: str):
    task_status = get_task_status(task_id=task_id)
    task_status.status = "pause"


async def delete_model_download_task(task_id: str):
    task_status = get_task_status(task_id)
    is_running = task_status.status == "doing"
    task_status.status = "waiting"
    await socket.send_json("deleteDownloadTask", task_id)

    # Pause the task
    if is_running:
        task_status.status = "pause"
        time.sleep(1)

    download_dir = utils.get_download_path()
    task_file_list = os.listdir(download_dir)
    for task_file in task_file_list:
        task_file_target = os.path.splitext(task_file)[0]
        if task_file_target == task_id:
            delete_task_status(task_id)
            os.remove(os.path.join(download_dir, task_file))

    await socket.send_json("deleteDownloadTask", task_id)
async def download_model(task_id: str):
    async def download_task(task_id: str):
        async def report_progress(task_status: TaskStatus):
            await socket.send_json("updateDownloadTask", task_status)

        try:
            # When starting a task from the queue, the task may not exist
            task_status = get_task_status(task_id)
        except Exception:
            return

        # Update task status
        task_status.status = "doing"
        await socket.send_json("updateDownloadTask", task_status)

        try:
            # Set download request headers
            headers = {"User-Agent": config.user_agent}

            download_platform = task_status.platform
            if download_platform == "civitai":
                api_key = utils.get_setting_value("api_key.civitai")
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            elif download_platform == "huggingface":
                api_key = utils.get_setting_value("api_key.huggingface")
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            progress_interval = 1.0
            await download_model_file(
                task_id=task_id,
                headers=headers,
                progress_callback=report_progress,
                interval=progress_interval,
            )
        except Exception as e:
            task_status.status = "pause"
            task_status.error = str(e)
            await socket.send_json("updateDownloadTask", task_status)
            task_status.error = None
            logging.error(str(e))

    try:
        status = download_thread_pool.submit(download_task, task_id)
        if status == "Waiting":
            task_status = get_task_status(task_id)
            task_status.status = "waiting"
            await socket.send_json("updateDownloadTask", task_status)
    except Exception as e:
        logging.error(traceback.format_exc())
        # task_status may never have been assigned here, so look it up safely
        task_status = download_model_task_status.get(task_id)
        if task_status is not None:
            task_status.status = "pause"
            task_status.error = str(e)
            await socket.send_json("updateDownloadTask", task_status)
            task_status.error = None
async def download_model_file(
    task_id: str,
    headers: dict,
    progress_callback: Callable[[TaskStatus], Awaitable[Any]],
    interval: float = 1.0,
):
    async def download_complete():
        """
        Restore the model information from the task file
        and move the model file to the target directory.
        """
        model_type = task_content.type
        path_index = task_content.pathIndex
        fullname = task_content.fullname
        # Write description file
        description = task_content.description
        description_file = os.path.join(download_path, f"{task_id}.md")
        with open(description_file, "w", encoding="utf-8") as f:
            f.write(description)

        model_path = utils.get_full_path(model_type, path_index, fullname)

        utils.rename_model(download_tmp_file, model_path)

        time.sleep(1)
        task_file = os.path.join(download_path, f"{task_id}.task")
        os.remove(task_file)
        await socket.send_json("completeDownloadTask", task_id)

    async def update_progress():
        nonlocal last_update_time
        nonlocal last_downloaded_size
        progress = (downloaded_size / total_size) * 100 if total_size > 0 else 0
        task_status.downloadedSize = downloaded_size
        task_status.progress = progress
        task_status.bps = downloaded_size - last_downloaded_size
        await progress_callback(task_status)
        last_update_time = time.time()
        last_downloaded_size = downloaded_size

    task_status = get_task_status(task_id)
    task_content = get_task_content(task_id)

    # Check download url
    model_url = task_content.downloadUrl
    if not model_url:
        raise RuntimeError("No downloadUrl found")

    download_path = utils.get_download_path()
    download_tmp_file = os.path.join(download_path, f"{task_id}.download")

    downloaded_size = 0
    if os.path.isfile(download_tmp_file):
        downloaded_size = os.path.getsize(download_tmp_file)
        headers["Range"] = f"bytes={downloaded_size}-"

    total_size = task_content.sizeBytes

    if total_size > 0 and downloaded_size == total_size:
        await download_complete()
        return

    last_update_time = time.time()
    last_downloaded_size = downloaded_size

    response = requests.get(
        url=model_url,
        headers=headers,
        stream=True,
        allow_redirects=True,
    )

    if response.status_code not in (200, 206):
        raise RuntimeError(
            f"Failed to download {task_content.fullname}, status code: {response.status_code}"
        )

    # Some models require logging in before they can be downloaded.
    # If no token is carried, the request is redirected to the login page.
    content_type = response.headers.get("content-type")
    if content_type and content_type.startswith("text/html"):
        raise RuntimeError(
            f"{task_content.fullname} needs to be logged in to download. Please set the API key first."
        )

    # When model information is parsed from the HuggingFace API, the file size
    # may be missing and has to be read from the response headers.
    if total_size == 0:
        total_size = int(response.headers.get("content-length", 0))
        task_content.sizeBytes = total_size
        task_status.totalSize = total_size
        set_task_content(task_id, task_content)
        await socket.send_json("updateDownloadTask", task_status)

    with open(download_tmp_file, "ab") as f:
        for chunk in response.iter_content(chunk_size=8192):
            if task_status.status == "pause":
                break

            f.write(chunk)
            downloaded_size += len(chunk)

            if time.time() - last_update_time >= interval:
                await update_progress()

    await update_progress()

    if total_size > 0 and downloaded_size == total_size:
        await download_complete()
    else:
        task_status.status = "pause"
        await socket.send_json("updateDownloadTask", task_status)
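Note: resuming relies on a standard HTTP Range request, as implemented above;
a standalone sketch of the core idea (the URL and file name are hypothetical):

    import os, requests

    tmp = "downloads/deadbeef.download"
    headers = {}
    downloaded = os.path.getsize(tmp) if os.path.isfile(tmp) else 0
    if downloaded:
        headers["Range"] = f"bytes={downloaded}-"  # request only the remainder
    r = requests.get("https://example.com/model.safetensors",
                     headers=headers, stream=True, allow_redirects=True)
    assert r.status_code in (200, 206)  # 206 = Partial Content
    with open(tmp, "ab") as f:  # append, matching the "ab" mode above
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)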
145  py/services.py  Normal file
@@ -0,0 +1,145 @@
import os
import logging
import traceback
import folder_paths

from typing import Any
from multidict import MultiDictProxy
from . import utils
from . import socket
from . import download


async def connect_websocket(request):
    async def message_handler(event_type: str, detail: Any, sid: str):
        try:
            if event_type == "downloadTaskList":
                await download.scan_model_download_task_list(sid=sid)

            if event_type == "resumeDownloadTask":
                await download.download_model(task_id=detail)

            if event_type == "pauseDownloadTask":
                await download.pause_model_download_task(task_id=detail)

            if event_type == "deleteDownloadTask":
                await download.delete_model_download_task(task_id=detail)
        except Exception:
            logging.error(traceback.format_exc())

    ws = await socket.create_websocket_handler(request, handler=message_handler)
    return ws
def scan_models_by_model_type(model_type: str):
    """
    Scans all models of the given model type and returns a list of models.
    """
    out = []
    folders, extensions = folder_paths.folder_names_and_paths[model_type]
    for path_index, base_path in enumerate(folders):
        files = utils.recursive_search_files(base_path)

        models = folder_paths.filter_files_extensions(files, extensions)

        for fullname in models:
            """
            fullname is the model path relative to base_path,
            e.g.
            abs_path is /path/to/models/stable-diffusion/custom_group/model_name.ckpt
            base_path is /path/to/models/stable-diffusion
            fullname is custom_group/model_name.ckpt
            basename is custom_group/model_name
            extension is .ckpt
            """

            fullname = fullname.replace(os.path.sep, "/")
            basename = os.path.splitext(fullname)[0]
            extension = os.path.splitext(fullname)[1]
            prefix_path = fullname.replace(os.path.basename(fullname), "")

            abs_path = os.path.join(base_path, fullname)
            file_stats = os.stat(abs_path)

            # Resolve metadata
            metadata = utils.get_model_metadata(abs_path)

            # Resolve preview
            image_name = utils.get_model_preview_name(abs_path)
            image_name = os.path.join(prefix_path, image_name)
            abs_image_path = os.path.join(base_path, image_name)
            if os.path.isfile(abs_image_path):
                image_state = os.stat(abs_image_path)
                image_timestamp = round(image_state.st_mtime_ns / 1000000)
                image_name = f"{image_name}?ts={image_timestamp}"
            model_preview = (
                f"/model-manager/preview/{model_type}/{path_index}/{image_name}"
            )

            # Resolve description
            description_file = utils.get_model_description_name(abs_path)
            description_file = os.path.join(prefix_path, description_file)
            abs_desc_path = os.path.join(base_path, description_file)
            description = None
            if os.path.isfile(abs_desc_path):
                with open(abs_desc_path, "r", encoding="utf-8") as f:
                    description = f.read()

            out.append(
                {
                    "fullname": fullname,
                    "basename": basename,
                    "extension": extension,
                    "type": model_type,
                    "pathIndex": path_index,
                    "sizeBytes": file_stats.st_size,
                    "preview": model_preview,
                    "description": description,
                    "createdAt": round(file_stats.st_ctime_ns / 1000000),
                    "updatedAt": round(file_stats.st_mtime_ns / 1000000),
                    "metadata": metadata,
                }
            )

    return out
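Note: one hypothetical entry of the returned list, using the paths from the
docstring above (all values are made up):

    {
        "fullname": "custom_group/model_name.ckpt",
        "basename": "custom_group/model_name",
        "extension": ".ckpt",
        "type": "checkpoints",
        "pathIndex": 0,
        "sizeBytes": 2132856622,
        "preview": "/model-manager/preview/checkpoints/0/custom_group/model_name.png?ts=1700000000000",
        "description": None,
        "createdAt": 1700000000000,
        "updatedAt": 1700000000000,
        "metadata": {},
    }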
def update_model(model_path: str, post: MultiDictProxy):
    if "previewFile" in post:
        previewFile = post["previewFile"]
        utils.save_model_preview_image(model_path, previewFile)

    if "description" in post:
        description = post["description"]
        utils.save_model_description(model_path, description)

    if "type" in post and "pathIndex" in post and "fullname" in post:
        model_type = post.get("type", None)
        path_index = post.get("pathIndex", None)
        fullname = post.get("fullname", None)
        if model_type is None or path_index is None or fullname is None:
            raise RuntimeError("Invalid type or pathIndex or fullname")
        path_index = int(path_index)

        # get new path
        new_model_path = utils.get_full_path(model_type, path_index, fullname)

        utils.rename_model(model_path, new_model_path)


def remove_model(model_path: str):
    model_dirname = os.path.dirname(model_path)
    os.remove(model_path)

    model_previews = utils.get_model_all_images(model_path)
    for preview in model_previews:
        os.remove(os.path.join(model_dirname, preview))

    model_descriptions = utils.get_model_all_descriptions(model_path)
    for description in model_descriptions:
        os.remove(os.path.join(model_dirname, description))


async def create_model_download_task(post):
    dict_post = dict(post)
    return await download.create_model_download_task(dict_post)
63  py/socket.py  Normal file
@@ -0,0 +1,63 @@
import aiohttp
import logging
import uuid
import json
from aiohttp import web
from typing import Any, Callable, Awaitable
from . import utils


__sockets: dict[str, web.WebSocketResponse] = {}


async def create_websocket_handler(
    request, handler: Callable[[str, Any, str], Awaitable[Any]]
):
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    sid = request.rel_url.query.get("clientId", "")
    if sid:
        # Reusing existing session, remove old
        __sockets.pop(sid, None)
    else:
        sid = uuid.uuid4().hex

    __sockets[sid] = ws

    try:
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                logging.warning(
                    "ws connection closed with exception %s" % ws.exception()
                )
            if msg.type == aiohttp.WSMsgType.TEXT:
                data = json.loads(msg.data)
                await handler(data.get("type"), data.get("detail"), sid)
    finally:
        __sockets.pop(sid, None)
    return ws
async def send_json(event: str, data: Any, sid: str = None):
    detail = utils.unpack_dataclass(data)
    message = {"type": event, "data": detail}

    if sid is None:
        socket_list = list(__sockets.values())
        for ws in socket_list:
            await __send_socket_catch_exception(ws.send_json, message)
    elif sid in __sockets:
        await __send_socket_catch_exception(__sockets[sid].send_json, message)


async def __send_socket_catch_exception(function, message):
    try:
        await function(message)
    except (
        aiohttp.ClientError,
        aiohttp.ClientPayloadError,
        ConnectionResetError,
        BrokenPipeError,
        ConnectionError,
    ) as err:
        logging.warning("send error: {}".format(err))
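Note: every event travels in a {"type", "data"} envelope; a hypothetical
broadcast (sid=None sends to all connected clients):

    await send_json("updateDownloadTask", {"taskId": "abc", "progress": 42.0})
    # each client receives:
    # {"type": "updateDownloadTask", "data": {"taskId": "abc", "progress": 42.0}}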
64  py/thread.py  Normal file
@@ -0,0 +1,64 @@
import asyncio
import threading
import queue
import logging
from . import utils


class DownloadThreadPool:
    def __init__(self) -> None:
        self.workers_count = 0
        self.task_queue = queue.Queue()
        self.running_tasks = set()
        self._lock = threading.Lock()

        default_max_workers = 5
        max_workers: int = utils.get_setting_value(
            "download.max_task_count", default_max_workers
        )

        if max_workers <= 0:
            max_workers = default_max_workers
            utils.set_setting_value("download.max_task_count", max_workers)

        self.max_worker = max_workers

    def submit(self, task, task_id):
        with self._lock:
            if task_id in self.running_tasks:
                return "Existing"
            self.running_tasks.add(task_id)
            self.task_queue.put((task, task_id))
        return self._adjust_worker_count()

    def _adjust_worker_count(self):
        if self.workers_count < self.max_worker:
            self._start_worker()
            return "Running"
        else:
            return "Waiting"

    def _start_worker(self):
        t = threading.Thread(target=self._worker, daemon=True)
        t.start()
        with self._lock:
            self.workers_count += 1

    def _worker(self):
        loop = asyncio.new_event_loop()

        while True:
            if self.task_queue.empty():
                break

            task, task_id = self.task_queue.get()

            try:
                loop.run_until_complete(task(task_id))
            except Exception as e:
                logging.error(f"worker run error: {str(e)}")
            finally:
                # Always release the task id, even when the task raised,
                # so the same task can be resubmitted later.
                with self._lock:
                    self.running_tasks.discard(task_id)

        with self._lock:
            self.workers_count -= 1
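Note: submit() returns "Existing", "Running", or "Waiting"; download.py only
reacts to "Waiting" (the task is queued behind max_task_count busy workers).
A minimal usage sketch with a hypothetical coroutine (the constructor reads
settings via utils.get_setting_value, so this assumes a running ComfyUI):

    pool = DownloadThreadPool()

    async def job(task_id: str):
        print("downloading", task_id)

    pool.submit(job, "task-1")  # "Running" if a new worker thread was started
    pool.submit(job, "task-1")  # "Existing" while the same id is still active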
282  py/utils.py  Normal file
@@ -0,0 +1,282 @@
import os
import json
import pickle
import logging
import comfy.utils
import folder_paths
from aiohttp import web
from typing import Any
from dataclasses import asdict, is_dataclass
from . import config


def resolve_model_base_paths():
    folders = list(folder_paths.folder_names_and_paths.keys())
    config.model_base_paths = {}
    for folder in folders:
        if folder == "configs":
            continue
        if folder == "custom_nodes":
            continue
        config.model_base_paths[folder] = folder_paths.get_folder_paths(folder)


def get_full_path(model_type: str, path_index: int, filename: str):
    """
    Build the absolute path for a model of the given type by joining
    the indexed base path with the filename.
    """
    folders = config.model_base_paths.get(model_type, [])
    if not path_index < len(folders):
        raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
    base_path = folders[path_index]
    return os.path.join(base_path, filename)
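Note: a hypothetical call, assuming "checkpoints" resolved to a base path
under the ComfyUI models directory:

    get_full_path("checkpoints", 0, "sd15/model.ckpt")
    # -> "/path/to/ComfyUI/models/checkpoints/sd15/model.ckpt"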
def get_valid_full_path(model_type: str, path_index: int, filename: str):
    """
    Like get_full_path, but also checks that the file actually exists.
    """
    folders = config.model_base_paths.get(model_type, [])
    if not path_index < len(folders):
        raise RuntimeError(f"PathIndex {path_index} is not in {model_type}")
    base_path = folders[path_index]
    full_path = os.path.join(base_path, filename)
    if os.path.isfile(full_path):
        return full_path
    elif os.path.islink(full_path):
        raise RuntimeError(
            f"WARNING path {full_path} exists but is a broken symlink, skipping."
        )


def get_download_path():
    download_path = os.path.join(config.extension_uri, "downloads")
    if not os.path.exists(download_path):
        os.makedirs(download_path)
    return download_path


def recursive_search_files(directory: str):
    files, folder_all = folder_paths.recursive_search(
        directory, excluded_dir_names=[".git"]
    )
    files.sort()
    return files


def search_files(directory: str):
    entries = os.listdir(directory)
    files = [f for f in entries if os.path.isfile(os.path.join(directory, f))]
    files.sort()
    return files
def get_model_metadata(filename: str):
    if not filename.endswith(".safetensors"):
        return {}
    try:
        out = comfy.utils.safetensors_header(filename, max_size=1024 * 1024)
        if out is None:
            return {}
        dt = json.loads(out)
        if "__metadata__" not in dt:
            return {}
        return dt["__metadata__"]
    except Exception:
        return {}


def get_model_all_images(model_path: str):
    base_dirname = os.path.dirname(model_path)
    files = search_files(base_dirname)
    files = folder_paths.filter_files_content_types(files, ["image"])

    basename = os.path.splitext(os.path.basename(model_path))[0]
    output: list[str] = []
    for file in files:
        file_basename = os.path.splitext(file)[0]
        if file_basename == basename:
            output.append(file)
        if file_basename == f"{basename}.preview":
            output.append(file)
    return output


def get_model_preview_name(model_path: str):
    images = get_model_all_images(model_path)
    return images[0] if len(images) > 0 else "no-preview.png"
def save_model_preview_image(model_path: str, image_file: Any):
    if not isinstance(image_file, web.FileField):
        raise RuntimeError("Invalid image file")

    content_type: str = image_file.content_type
    if not content_type.startswith("image/"):
        raise RuntimeError(f"FileTypeError: expected image, got {content_type}")

    base_dirname = os.path.dirname(model_path)

    # remove old preview images
    old_preview_images = get_model_all_images(model_path)
    a1111_civitai_helper_image = False
    for image in old_preview_images:
        # "<name>.preview.<ext>" is the a1111 civitai_helper naming convention
        if os.path.splitext(image)[0].endswith(".preview"):
            a1111_civitai_helper_image = True
        image_path = os.path.join(base_dirname, image)
        os.remove(image_path)

    # save new preview image
    basename = os.path.splitext(os.path.basename(model_path))[0]
    extension = f".{content_type.split('/')[1]}"
    new_preview_path = os.path.join(base_dirname, f"{basename}{extension}")

    # read the upload once; the stream cannot be re-read for the second copy
    image_data = image_file.file.read()

    with open(new_preview_path, "wb") as f:
        f.write(image_data)

    # TODO Is it possible to abandon the current rules and adopt the rules of a1111 civitai_helper?
    if a1111_civitai_helper_image:
        # Keep a preview image under the a1111 civitai_helper convention as well
        new_preview_path = os.path.join(base_dirname, f"{basename}.preview{extension}")
        with open(new_preview_path, "wb") as f:
            f.write(image_data)
def get_model_all_descriptions(model_path: str):
    base_dirname = os.path.dirname(model_path)
    files = search_files(base_dirname)
    files = folder_paths.filter_files_extensions(files, [".txt", ".md"])

    basename = os.path.splitext(os.path.basename(model_path))[0]
    output: list[str] = []
    for file in files:
        file_basename = os.path.splitext(file)[0]
        if file_basename == basename:
            output.append(file)
    return output


def get_model_description_name(model_path: str):
    descriptions = get_model_all_descriptions(model_path)
    basename = os.path.splitext(os.path.basename(model_path))[0]
    return descriptions[0] if len(descriptions) > 0 else f"{basename}.md"


def save_model_description(model_path: str, content: Any):
    if not isinstance(content, str):
        raise RuntimeError("Invalid description")

    base_dirname = os.path.dirname(model_path)

    # remove old descriptions
    old_descriptions = get_model_all_descriptions(model_path)
    for desc in old_descriptions:
        description_path = os.path.join(base_dirname, desc)
        os.remove(description_path)

    # save new description
    basename = os.path.splitext(os.path.basename(model_path))[0]
    extension = ".md"
    new_desc_path = os.path.join(base_dirname, f"{basename}{extension}")

    with open(new_desc_path, "w", encoding="utf-8") as f:
        f.write(content)
def rename_model(model_path: str, new_model_path: str):
    if model_path == new_model_path:
        return

    if os.path.exists(new_model_path):
        raise RuntimeError(f"Model {new_model_path} already exists")

    model_name = os.path.splitext(os.path.basename(model_path))[0]
    new_model_name = os.path.splitext(os.path.basename(new_model_path))[0]

    model_dirname = os.path.dirname(model_path)
    new_model_dirname = os.path.dirname(new_model_path)

    if not os.path.exists(new_model_dirname):
        os.makedirs(new_model_dirname)

    # move model
    os.rename(model_path, new_model_path)

    # move previews (matching by the old basename still works after the model
    # file itself has been moved, since only sibling files are listed)
    previews = get_model_all_images(model_path)
    for preview in previews:
        preview_path = os.path.join(model_dirname, preview)
        preview_name = os.path.splitext(preview)[0]
        preview_ext = os.path.splitext(preview)[1]
        new_preview_path = (
            os.path.join(new_model_dirname, new_model_name + preview_ext)
            if preview_name == model_name
            else os.path.join(
                new_model_dirname, new_model_name + ".preview" + preview_ext
            )
        )
        os.rename(preview_path, new_preview_path)

    # move description
    description = get_model_description_name(model_path)
    description_path = os.path.join(model_dirname, description)
    if os.path.isfile(description_path):
        new_description_path = os.path.join(new_model_dirname, f"{new_model_name}.md")
        os.rename(description_path, new_description_path)
def save_dict_pickle_file(filename: str, data: dict):
    with open(filename, "wb") as f:
        pickle.dump(data, f)


def load_dict_pickle_file(filename: str) -> dict:
    with open(filename, "rb") as f:
        data = pickle.load(f)
    return data
def resolve_setting_key(key: str) -> str:
    key_paths = key.split(".")
    setting_id = config.setting_key
    try:
        for key_path in key_paths:
            setting_id = setting_id[key_path]
    except (KeyError, TypeError):
        pass
    if not isinstance(setting_id, str):
        raise RuntimeError(f"Invalid key: {key}")

    return setting_id


def set_setting_value(key: str, value: Any):
    setting_id = resolve_setting_key(key)
    fake_request = config.FakeRequest()
    settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
    settings[setting_id] = value
    config.serverInstance.user_manager.settings.save_settings(fake_request, settings)


def get_setting_value(key: str, default: Any = None) -> Any:
    setting_id = resolve_setting_key(key)
    fake_request = config.FakeRequest()
    settings = config.serverInstance.user_manager.settings.get_settings(fake_request)
    return settings.get(setting_id, default)
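Note: a hedged usage sketch (requires a running PromptServer, since the values
are stored through server.PromptServer.instance.user_manager.settings):

    set_setting_value("download.max_task_count", 5)
    get_setting_value("download.max_task_count", 5)  # -> 5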
def unpack_dataclass(data: Any):
    if isinstance(data, dict):
        return {key: unpack_dataclass(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [unpack_dataclass(x) for x in data]
    elif is_dataclass(data):
        return asdict(data)
    else:
        return data
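Note: unpack_dataclass recurses through dicts and lists, so nested dataclasses
serialize too; a quick sketch with a hypothetical dataclass:

    from dataclasses import dataclass

    @dataclass
    class Point:
        x: int
        y: int

    unpack_dataclass({"points": [Point(1, 2)]})
    # -> {"points": [{"x": 1, "y": 2}]}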