diff --git a/README.md b/README.md
index ee81c9c..2d4e3c1 100644
--- a/README.md
+++ b/README.md
@@ -61,3 +61,10 @@ npm run build
- Read, edit and save notes. (Saved as a `.md` file beside the model).
- Change or remove a model's preview image.
- View training tags and use the random tag generator to generate prompt ideas. (Inspired by the one in A1111.)
+
+### Scan Model Information
+
+
+
+- Scan models and try to download information & preview.
+- Support migration from `cdb-boop/ComfyUI-Model-Manager/main`
diff --git a/__init__.py b/__init__.py
index 31e6d87..c5a0770 100644
--- a/__init__.py
+++ b/__init__.py
@@ -3,9 +3,26 @@ import folder_paths
from .py import config
from .py import utils
+extension_uri = utils.normalize_path(os.path.dirname(__file__))
+
+requirements_path = utils.join_path(extension_uri, "requirements.txt")
+
+with open(requirements_path, "r", encoding="utf-8") as f:
+ requirements = f.readlines()
+
+requirements = [x.strip() for x in requirements]
+requirements = [x for x in requirements if not x.startswith("#")]
+
+uninstalled_package = [p for p in requirements if not utils.is_installed(p)]
+
+if len(uninstalled_package) > 0:
+ utils.print_info(f"Install dependencies...")
+ for p in uninstalled_package:
+ utils.pip_install(p)
+
# Init config settings
-config.extension_uri = utils.normalize_path(os.path.dirname(__file__))
+config.extension_uri = extension_uri
utils.resolve_model_base_paths()
version = utils.get_current_version()
@@ -97,9 +114,8 @@ async def create_model(request):
- downloadUrl: download url.
- hash: a JSON string containing the hash value of the downloaded model.
"""
- post = await request.post()
+ task_data = await request.json()
try:
- task_data = dict(post)
task_id = await services.create_model_download_task(task_data, request)
return web.json_response({"success": True, "data": {"taskId": task_id}})
except Exception as e:
@@ -158,13 +174,12 @@ async def update_model(request):
index = int(request.match_info.get("index", None))
filename = request.match_info.get("filename", None)
- post: dict = await request.post()
+ model_data: dict = await request.json()
try:
model_path = utils.get_valid_full_path(model_type, index, filename)
if model_path is None:
raise RuntimeError(f"File {filename} not found")
- model_data = dict(post)
services.update_model(model_path, model_data)
return web.json_response({"success": True})
except Exception as e:
@@ -194,6 +209,37 @@ async def delete_model(request):
return web.json_response({"success": False, "error": error_msg})
+@routes.get("/model-manager/model-info")
+async def fetch_model_info(request):
+ """
+ Fetch model information from network with model page.
+ """
+ try:
+ model_page = request.query.get("model-page", None)
+ result = services.fetch_model_info(model_page)
+ return web.json_response({"success": True, "data": result})
+ except Exception as e:
+ error_msg = f"Fetch model info failed: {str(e)}"
+ utils.print_error(error_msg)
+ return web.json_response({"success": False, "error": error_msg})
+
+
+@routes.post("/model-manager/model-info/scan")
+async def download_model_info(request):
+ """
+ Create a task to download model information.
+ """
+ post = await utils.get_request_body(request)
+ try:
+ scan_mode = post.get("scanMode", "diff")
+ await services.download_model_info(scan_mode)
+ return web.json_response({"success": True})
+ except Exception as e:
+ error_msg = f"Download model info failed: {str(e)}"
+ utils.print_error(error_msg)
+ return web.json_response({"success": False, "error": error_msg})
+
+
@routes.get("/model-manager/preview/{type}/{index}/{filename:.*}")
async def read_model_preview(request):
"""
@@ -236,6 +282,20 @@ async def read_download_preview(request):
return web.FileResponse(preview_path)
+@routes.post("/model-manager/migrate")
+async def migrate_legacy_information(request):
+ """
+ Migrate legacy information.
+ """
+ try:
+ await services.migrate_legacy_information()
+ return web.json_response({"success": True})
+ except Exception as e:
+ error_msg = f"Download model info failed: {str(e)}"
+ utils.print_error(error_msg)
+ return web.json_response({"success": False, "error": error_msg})
+
+
WEB_DIRECTORY = "web"
NODE_CLASS_MAPPINGS = {}
__all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS"]
diff --git a/demo/scan-model-info.png b/demo/scan-model-info.png
new file mode 100755
index 0000000..7a26a3e
Binary files /dev/null and b/demo/scan-model-info.png differ
diff --git a/package.json b/package.json
index 3706b2e..50e0865 100644
--- a/package.json
+++ b/package.json
@@ -14,7 +14,6 @@
"@types/lodash": "^4.17.9",
"@types/markdown-it": "^14.1.2",
"@types/node": "^22.5.5",
- "@types/turndown": "^5.0.5",
"@vitejs/plugin-vue": "^5.1.4",
"autoprefixer": "^10.4.20",
"eslint": "^9.10.0",
@@ -40,15 +39,13 @@
"markdown-it": "^14.1.0",
"markdown-it-metadata-block": "^1.0.6",
"primevue": "^4.0.7",
- "turndown": "^7.2.0",
"vue": "^3.4.31",
"vue-i18n": "^9.13.1",
"yaml": "^2.6.0"
},
"lint-staged": {
"./**/*.{js,ts,tsx,vue}": [
- "prettier --write",
- "git add"
+ "prettier --write"
]
}
}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 3fc0dc2..2b59fef 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -26,9 +26,6 @@ importers:
primevue:
specifier: ^4.0.7
version: 4.0.7(vue@3.5.6(typescript@5.6.2))
- turndown:
- specifier: ^7.2.0
- version: 7.2.0
vue:
specifier: ^3.4.31
version: 3.5.6(typescript@5.6.2)
@@ -51,9 +48,6 @@ importers:
'@types/node':
specifier: ^22.5.5
version: 22.5.5
- '@types/turndown':
- specifier: ^5.0.5
- version: 5.0.5
'@vitejs/plugin-vue':
specifier: ^5.1.4
version: 5.1.4(vite@5.4.6(@types/node@22.5.5)(less@4.2.0))(vue@3.5.6(typescript@5.6.2))
@@ -349,9 +343,6 @@ packages:
'@jridgewell/trace-mapping@0.3.25':
resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==}
- '@mixmark-io/domino@2.2.0':
- resolution: {integrity: sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw==}
-
'@nodelib/fs.scandir@2.1.5':
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
engines: {node: '>= 8'}
@@ -502,9 +493,6 @@ packages:
'@types/node@22.5.5':
resolution: {integrity: sha512-Xjs4y5UPO/CLdzpgR6GirZJx36yScjh73+2NlLlkFRSoQN8B0DpfXPdZGnvVmLRLOsqDpOfTNv7D9trgGhmOIA==}
- '@types/turndown@5.0.5':
- resolution: {integrity: sha512-TL2IgGgc7B5j78rIccBtlYAnkuv8nUQqhQc+DSYV5j9Be9XOcm/SKOVRuA47xAVI3680Tk9B1d8flK2GWT2+4w==}
-
'@typescript-eslint/eslint-plugin@8.13.0':
resolution: {integrity: sha512-nQtBLiZYMUPkclSeC3id+x4uVd1SGtHuElTxL++SfP47jR0zfkZBJHc+gL4qPsgTuypz0k8Y2GheaDYn6Gy3rg==}
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
@@ -1614,9 +1602,6 @@ packages:
tslib@2.7.0:
resolution: {integrity: sha512-gLXCKdN1/j47AiHiOkJN69hJmcbGTHI0ImLmbYLHykhgeN0jVGola9yVjFgzCUklsZQMW55o+dW7IXv3RCXDzA==}
- turndown@7.2.0:
- resolution: {integrity: sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==}
-
type-check@0.4.0:
resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==}
engines: {node: '>= 0.8.0'}
@@ -1929,8 +1914,6 @@ snapshots:
'@jridgewell/resolve-uri': 3.1.2
'@jridgewell/sourcemap-codec': 1.5.0
- '@mixmark-io/domino@2.2.0': {}
-
'@nodelib/fs.scandir@2.1.5':
dependencies:
'@nodelib/fs.stat': 2.0.5
@@ -2038,8 +2021,6 @@ snapshots:
dependencies:
undici-types: 6.19.8
- '@types/turndown@5.0.5': {}
-
'@typescript-eslint/eslint-plugin@8.13.0(@typescript-eslint/parser@8.13.0(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2))(eslint@9.10.0(jiti@1.21.6))(typescript@5.6.2)':
dependencies:
'@eslint-community/regexpp': 4.12.1
@@ -3164,10 +3145,6 @@ snapshots:
tslib@2.7.0: {}
- turndown@7.2.0:
- dependencies:
- '@mixmark-io/domino': 2.2.0
-
type-check@0.4.0:
dependencies:
prelude-ls: 1.2.1
diff --git a/py/download.py b/py/download.py
index 0fee52b..84f61d2 100644
--- a/py/download.py
+++ b/py/download.py
@@ -24,6 +24,34 @@ class TaskStatus:
bps: float = 0
error: Optional[str] = None
+ def __init__(self, **kwargs):
+ self.taskId = kwargs.get("taskId", None)
+ self.type = kwargs.get("type", None)
+ self.fullname = kwargs.get("fullname", None)
+ self.preview = kwargs.get("preview", None)
+ self.status = kwargs.get("status", "pause")
+ self.platform = kwargs.get("platform", None)
+ self.downloadedSize = kwargs.get("downloadedSize", 0)
+ self.totalSize = kwargs.get("totalSize", 0)
+ self.progress = kwargs.get("progress", 0)
+ self.bps = kwargs.get("bps", 0)
+ self.error = kwargs.get("error", None)
+
+ def to_dict(self):
+ return {
+ "taskId": self.taskId,
+ "type": self.type,
+ "fullname": self.fullname,
+ "preview": self.preview,
+ "status": self.status,
+ "platform": self.platform,
+ "downloadedSize": self.downloadedSize,
+ "totalSize": self.totalSize,
+ "progress": self.progress,
+ "bps": self.bps,
+ "error": self.error,
+ }
+
@dataclass
class TaskContent:
@@ -33,9 +61,31 @@ class TaskContent:
description: str
downloadPlatform: str
downloadUrl: str
- sizeBytes: float
+ sizeBytes: int
hashes: Optional[dict[str, str]] = None
+ def __init__(self, **kwargs):
+ self.type = kwargs.get("type", None)
+ self.pathIndex = int(kwargs.get("pathIndex", 0))
+ self.fullname = kwargs.get("fullname", None)
+ self.description = kwargs.get("description", None)
+ self.downloadPlatform = kwargs.get("downloadPlatform", None)
+ self.downloadUrl = kwargs.get("downloadUrl", None)
+ self.sizeBytes = int(kwargs.get("sizeBytes", 0))
+ self.hashes = kwargs.get("hashes", None)
+
+ def to_dict(self):
+ return {
+ "type": self.type,
+ "pathIndex": self.pathIndex,
+ "fullname": self.fullname,
+ "description": self.description,
+ "downloadPlatform": self.downloadPlatform,
+ "downloadUrl": self.downloadUrl,
+ "sizeBytes": self.sizeBytes,
+ "hashes": self.hashes,
+ }
+
download_model_task_status: dict[str, TaskStatus] = {}
download_thread_pool = thread.DownloadThreadPool()
@@ -44,7 +94,7 @@ download_thread_pool = thread.DownloadThreadPool()
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
download_path = utils.get_download_path()
task_file_path = utils.join_path(download_path, f"{task_id}.task")
- utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
+ utils.save_dict_pickle_file(task_file_path, task_content)
def get_task_content(task_id: str):
@@ -53,8 +103,6 @@ def get_task_content(task_id: str):
if not os.path.isfile(task_file):
raise RuntimeError(f"Task {task_id} not found")
task_content = utils.load_dict_pickle_file(task_file)
- task_content["pathIndex"] = int(task_content.get("pathIndex", 0))
- task_content["sizeBytes"] = float(task_content.get("sizeBytes", 0))
return TaskContent(**task_content)
@@ -106,14 +154,14 @@ async def scan_model_download_task_list():
for task_file in task_files:
task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id)
- task_list.append(task_status)
+ task_list.append(task_status.to_dict())
- return utils.unpack_dataclass(task_list)
+ return task_list
async def create_model_download_task(task_data: dict, request):
"""
- Creates a download task for the given post.
+ Creates a download task for the given data.
"""
model_type = task_data.get("type", None)
path_index = int(task_data.get("pathIndex", None))
@@ -132,8 +180,8 @@ async def create_model_download_task(task_data: dict, request):
raise RuntimeError(f"Task {task_id} already exists")
try:
- previewFile = task_data.pop("previewFile", None)
- utils.save_model_preview_image(task_path, previewFile)
+ preview_url = task_data.pop("preview", None)
+ utils.save_model_preview_image(task_path, preview_url)
set_task_content(task_id, task_data)
task_status = TaskStatus(
taskId=task_id,
@@ -144,7 +192,7 @@ async def create_model_download_task(task_data: dict, request):
totalSize=float(task_data.get("sizeBytes", 0)),
)
download_model_task_status[task_id] = task_status
- await utils.send_json("create_download_task", task_status)
+ await utils.send_json("create_download_task", task_status.to_dict())
except Exception as e:
await delete_model_download_task(task_id)
raise RuntimeError(str(e)) from e
@@ -183,7 +231,7 @@ async def delete_model_download_task(task_id: str):
async def download_model(task_id: str, request):
async def download_task(task_id: str):
async def report_progress(task_status: TaskStatus):
- await utils.send_json("update_download_task", task_status)
+ await utils.send_json("update_download_task", task_status.to_dict())
try:
# When starting a task from the queue, the task may not exist
@@ -193,7 +241,7 @@ async def download_model(task_id: str, request):
# Update task status
task_status.status = "doing"
- await utils.send_json("update_download_task", task_status)
+ await utils.send_json("update_download_task", task_status.to_dict())
try:
@@ -221,7 +269,7 @@ async def download_model(task_id: str, request):
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
- await utils.send_json("update_download_task", task_status)
+ await utils.send_json("update_download_task", task_status.to_dict())
task_status.error = None
utils.print_error(str(e))
@@ -230,11 +278,11 @@ async def download_model(task_id: str, request):
if status == "Waiting":
task_status = get_task_status(task_id)
task_status.status = "waiting"
- await utils.send_json("update_download_task", task_status)
+ await utils.send_json("update_download_task", task_status.to_dict())
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
- await utils.send_json("update_download_task", task_status)
+ await utils.send_json("update_download_task", task_status.to_dict())
task_status.error = None
utils.print_error(str(e))
@@ -339,7 +387,7 @@ async def download_model_file(
task_content.sizeBytes = total_size
task_status.totalSize = total_size
set_task_content(task_id, task_content)
- await utils.send_json("update_download_task", task_content)
+ await utils.send_json("update_download_task", task_content.to_dict())
with open(download_tmp_file, "ab") as f:
for chunk in response.iter_content(chunk_size=8192):
@@ -358,4 +406,4 @@ async def download_model_file(
await download_complete()
else:
task_status.status = "pause"
- await utils.send_json("update_download_task", task_status)
+ await utils.send_json("update_download_task", task_status.to_dict())
diff --git a/py/searcher.py b/py/searcher.py
new file mode 100644
index 0000000..a2a1926
--- /dev/null
+++ b/py/searcher.py
@@ -0,0 +1,317 @@
+import os
+import re
+import yaml
+import requests
+import markdownify
+
+
+from abc import ABC, abstractmethod
+from urllib.parse import urlparse, parse_qs
+
+from . import utils
+
+
+class ModelSearcher(ABC):
+ """
+ Abstract class for model searcher.
+ """
+
+ @abstractmethod
+ def search_by_url(self, url: str) -> list[dict]:
+ pass
+
+ @abstractmethod
+ def search_by_hash(self, hash: str) -> dict:
+ pass
+
+
+class UnknownWebsiteSearcher(ModelSearcher):
+ def search_by_url(self, url: str):
+ raise RuntimeError(
+ f"Unknown Website, please input a URL from huggingface.co or civitai.com."
+ )
+
+ def search_by_hash(self, hash: str):
+ raise RuntimeError(f"Unknown Website, unable to search with hash value.")
+
+
+class CivitaiModelSearcher(ModelSearcher):
+ def search_by_url(self, url: str):
+ parsed_url = urlparse(url)
+
+ pathname = parsed_url.path
+ match = re.match(r"^/models/(\d*)", pathname)
+ model_id = match.group(1) if match else None
+
+ query_params = parse_qs(parsed_url.query)
+ version_id = query_params.get("modelVersionId", [None])[0]
+
+ if not model_id:
+ return []
+
+ response = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
+ response.raise_for_status()
+ res_data: dict = response.json()
+
+ model_versions: list[dict] = res_data["modelVersions"]
+ if version_id:
+ model_versions = utils.filter_with(model_versions, {"id": int(version_id)})
+
+ models: list[dict] = []
+
+ for version in model_versions:
+ model_files: list[dict] = version.get("files", [])
+ model_files = utils.filter_with(model_files, {"type": "Model"})
+
+ shortname = version.get("name", None) if len(model_files) > 0 else None
+
+ for file in model_files:
+ fullname = file.get("name", None)
+ extension = os.path.splitext(fullname)[1]
+ basename = os.path.splitext(fullname)[0]
+
+ metadata_info = {
+ "website": "Civitai",
+ "modelPage": f"https://civitai.com/models/{model_id}?modelVersionId={version.get('id')}",
+ "author": res_data.get("creator", {}).get("username", None),
+ "baseModel": version.get("baseModel"),
+ "hashes": file.get("hashes"),
+ "metadata": file.get("metadata"),
+ "preview": [i["url"] for i in version["images"]],
+ }
+
+ description_parts: list[str] = []
+ description_parts.append("---")
+ description_parts.append(yaml.dump(metadata_info).strip())
+ description_parts.append("---")
+ description_parts.append("")
+ description_parts.append(f"# Trigger Words")
+ description_parts.append("")
+ description_parts.append(
+ ", ".join(version.get("trainedWords", ["No trigger words"]))
+ )
+ description_parts.append("")
+ description_parts.append(f"# About this version")
+ description_parts.append("")
+ description_parts.append(
+ markdownify.markdownify(
+ version.get(
+ "description", "
No description about this version
" + ) + ).strip() + ) + description_parts.append("") + description_parts.append(f"# {res_data.get('name')}") + description_parts.append("") + description_parts.append( + markdownify.markdownify( + res_data.get( + "description", "No description about this model
" + ) + ).strip() + ) + description_parts.append("") + + model = { + "id": file.get("id"), + "shortname": shortname or basename, + "fullname": fullname, + "basename": basename, + "extension": extension, + "preview": metadata_info.get("preview"), + "sizeBytes": file.get("sizeKB", 0) * 1024, + "type": self._resolve_model_type(res_data.get("type", "unknown")), + "pathIndex": 0, + "description": "\n".join(description_parts), + "metadata": file.get("metadata"), + "downloadPlatform": "civitai", + "downloadUrl": file.get("downloadUrl"), + "hashes": file.get("hashes"), + } + models.append(model) + + return models + + def search_by_hash(self, hash: str): + if not hash: + raise RuntimeError(f"Hash value is empty.") + + response = requests.get( + f"https://civitai.com/api/v1/model-versions/by-hash/{hash}" + ) + response.raise_for_status() + version: dict = response.json() + + model_id = version.get("modelId") + version_id = version.get("id") + + model_page = ( + f"https://civitai.com/models/{model_id}?modelVersionId={version_id}" + ) + + models = self.search_by_url(model_page) + + for model in models: + sha256 = model.get("hashes", {}).get("SHA256") + if sha256 == hash: + return model + + return models[0] + + def _resolve_model_type(self, model_type: str): + map_legacy = { + "TextualInversion": "embeddings", + "LoCon": "loras", + "DoRA": "loras", + "Controlnet": "controlnet", + "Upscaler": "upscale_models", + "VAE": "vae", + "unknown": "unknown", + } + return map_legacy.get(model_type, f"{model_type.lower()}s") + + +class HuggingfaceModelSearcher(ModelSearcher): + def search_by_url(self, url: str): + parsed_url = urlparse(url) + + pathname = parsed_url.path + + space, name, *rest_paths = pathname.strip("/").split("/") + + model_id = f"{space}/{name}" + rest_pathname = "/".join(rest_paths) + + response = requests.get(f"https://huggingface.co/api/models/{model_id}") + response.raise_for_status() + res_data: dict = response.json() + + sibling_files: list[str] = [ + x.get("rfilename") for x in res_data.get("siblings", []) + ] + + model_files = utils.filter_with( + utils.filter_with(sibling_files, self._match_model_files()), + self._match_tree_files(rest_pathname), + ) + + image_files = utils.filter_with( + utils.filter_with(sibling_files, self._match_image_files()), + self._match_tree_files(rest_pathname), + ) + image_files = [ + f"https://huggingface.co/{model_id}/resolve/main/{filename}" + for filename in image_files + ] + + models: list[dict] = [] + + for filename in model_files: + fullname = os.path.basename(filename) + extension = os.path.splitext(fullname)[1] + basename = os.path.splitext(fullname)[0] + + description_parts: list[str] = [] + + metadata_info = { + "website": "HuggingFace", + "modelPage": f"https://huggingface.co/{model_id}", + "author": res_data.get("author", None), + "preview": image_files, + } + + description_parts: list[str] = [] + description_parts.append("---") + description_parts.append(yaml.dump(metadata_info).strip()) + description_parts.append("---") + description_parts.append("") + description_parts.append(f"# Trigger Words") + description_parts.append("") + description_parts.append("No trigger words") + description_parts.append("") + description_parts.append(f"# About this version") + description_parts.append("") + description_parts.append("No description about this version") + description_parts.append("") + description_parts.append(f"# {res_data.get('name')}") + description_parts.append("") + description_parts.append("No description about this model") + description_parts.append("") + + model = { + "id": filename, + "shortname": filename, + "fullname": fullname, + "basename": basename, + "extension": extension, + "preview": image_files, + "sizeBytes": 0, + "type": "unknown", + "pathIndex": 0, + "description": "\n".join(description_parts), + "metadata": {}, + "downloadPlatform": "", + "downloadUrl": f"https://huggingface.co/{model_id}/resolve/main/{filename}?download=true", + } + models.append(model) + + return models + + def search_by_hash(self, hash: str): + raise RuntimeError("Hash search is not supported by Huggingface.") + + def _match_model_files(self): + extension = [ + ".bin", + ".ckpt", + ".gguf", + ".onnx", + ".pt", + ".pth", + ".safetensors", + ] + + def _filter_model_files(file: str): + return any(file.endswith(ext) for ext in extension) + + return _filter_model_files + + def _match_image_files(self): + extension = [ + ".png", + ".webp", + ".jpeg", + ".jpg", + ".jfif", + ".gif", + ".apng", + ] + + def _filter_image_files(file: str): + return any(file.endswith(ext) for ext in extension) + + return _filter_image_files + + def _match_tree_files(self, pathname: str): + target, *paths = pathname.split("/") + + def _filter_tree_files(file: str): + if not target: + return True + if target != "tree" and target != "blob": + return True + + prefix_path = "/".join(paths) + return file.startswith(prefix_path) + + return _filter_tree_files + + +def get_model_searcher_by_url(url: str) -> ModelSearcher: + parsed_url = urlparse(url) + host_name = parsed_url.hostname + if host_name == "civitai.com": + return CivitaiModelSearcher() + elif host_name == "huggingface.co": + return HuggingfaceModelSearcher() + return UnknownWebsiteSearcher() diff --git a/py/services.py b/py/services.py index 9a6c7e5..6729987 100644 --- a/py/services.py +++ b/py/services.py @@ -5,6 +5,7 @@ import folder_paths from . import config from . import utils from . import download +from . import searcher def scan_models(): @@ -128,3 +129,180 @@ async def resume_model_download_task(task_id, request): async def delete_model_download_task(task_id): return await download.delete_model_download_task(task_id) + + +def fetch_model_info(model_page: str): + if not model_page: + return [] + + model_searcher = searcher.get_model_searcher_by_url(model_page) + result = model_searcher.search_by_url(model_page) + return result + + +async def download_model_info(scan_mode: str): + utils.print_info(f"Download model info for {scan_mode}") + model_base_paths = config.model_base_paths + for model_type in model_base_paths: + + folders, extensions = folder_paths.folder_names_and_paths[model_type] + for path_index, base_path in enumerate(folders): + files = utils.recursive_search_files(base_path) + + models = folder_paths.filter_files_extensions(files, extensions) + images = folder_paths.filter_files_content_types(files, ["image"]) + image_dict = utils.file_list_to_name_dict(images) + descriptions = folder_paths.filter_files_extensions(files, [".md"]) + description_dict = utils.file_list_to_name_dict(descriptions) + + for fullname in models: + fullname = utils.normalize_path(fullname) + basename = os.path.splitext(fullname)[0] + + abs_model_path = utils.join_path(base_path, fullname) + + image_name = image_dict.get(basename, "no-preview.png") + abs_image_path = utils.join_path(base_path, image_name) + + has_preview = os.path.isfile(abs_image_path) + + description_name = description_dict.get(basename, None) + abs_description_path = ( + utils.join_path(base_path, description_name) + if description_name + else None + ) + has_description = ( + os.path.isfile(abs_description_path) + if abs_description_path + else False + ) + + try: + + utils.print_info(f"Checking model {abs_model_path}") + utils.print_debug(f"Scan mode: {scan_mode}") + utils.print_debug(f"Has preview: {has_preview}") + utils.print_debug(f"Has description: {has_description}") + + if scan_mode != "full" and (has_preview and has_description): + continue + + utils.print_debug(f"Calculate sha256 for {abs_model_path}") + hash_value = utils.calculate_sha256(abs_model_path) + utils.print_info(f"Searching model info by hash {hash_value}") + model_info = searcher.CivitaiModelSearcher().search_by_hash( + hash_value + ) + + preview_url_list = model_info.get("preview", []) + preview_image_url = ( + preview_url_list[0] if preview_url_list else None + ) + if preview_image_url: + utils.print_debug(f"Save preview image to {abs_image_path}") + utils.save_model_preview_image( + abs_model_path, preview_image_url + ) + + description = model_info.get("description", None) + if description: + utils.save_model_description(abs_model_path, description) + except Exception as e: + utils.print_error( + f"Failed to download model info for {abs_model_path}: {e}" + ) + + utils.print_debug("Completed scan model information.") + + +async def migrate_legacy_information(): + import json + import yaml + from PIL import Image + + utils.print_info(f"Migrating legacy information...") + + model_base_paths = config.model_base_paths + for model_type in model_base_paths: + + folders, extensions = folder_paths.folder_names_and_paths[model_type] + for path_index, base_path in enumerate(folders): + files = utils.recursive_search_files(base_path) + + models = folder_paths.filter_files_extensions(files, extensions) + + for fullname in models: + fullname = utils.normalize_path(fullname) + + abs_model_path = utils.join_path(base_path, fullname) + + base_file_name = os.path.splitext(abs_model_path)[0] + + utils.print_debug(f"Try to migrate legacy info for {abs_model_path}") + + preview_path = utils.join_path( + os.path.dirname(abs_model_path), + utils.get_model_preview_name(abs_model_path), + ) + new_preview_path = f"{base_file_name}.webp" + + if os.path.isfile(preview_path) and preview_path != new_preview_path: + utils.print_info(f"Migrate preview image from {fullname}") + with Image.open(preview_path) as image: + image.save(new_preview_path, format="WEBP") + os.remove(preview_path) + + description_path = f"{base_file_name}.md" + + metadata_info = { + "website": "Civitai", + } + + url_info_path = f"{base_file_name}.url" + if os.path.isfile(url_info_path): + with open(url_info_path, "r", encoding="utf-8") as f: + for line in f: + if line.startswith("URL="): + model_page_url = line[len("URL=") :].strip() + metadata_info.update({"modelPage": model_page_url}) + + json_info_path = f"{base_file_name}.json" + if os.path.isfile(json_info_path): + with open(json_info_path, "r", encoding="utf-8") as f: + version = json.load(f) + metadata_info.update( + { + "baseModel": version.get("baseModel"), + "preview": [i["url"] for i in version["images"]], + } + ) + + description_parts: list[str] = [ + "---", + yaml.dump(metadata_info).strip(), + "---", + "", + ] + + text_info_path = f"{base_file_name}.txt" + if os.path.isfile(text_info_path): + with open(text_info_path, "r", encoding="utf-8") as f: + description_parts.append(f.read()) + + description_path = f"{base_file_name}.md" + + if os.path.isfile(text_info_path): + utils.print_info(f"Migrate description from {fullname}") + with open(description_path, "w", encoding="utf-8", newline="") as f: + f.write("\n".join(description_parts)) + + def try_to_remove_file(file_path): + if os.path.isfile(file_path): + os.remove(file_path) + + try_to_remove_file(url_info_path) + try_to_remove_file(text_info_path) + try_to_remove_file(json_info_path) + + utils.print_debug("Completed migrate model information.") diff --git a/py/utils.py b/py/utils.py index a3efa54..5e53db0 100644 --- a/py/utils.py +++ b/py/utils.py @@ -29,6 +29,27 @@ def print_debug(msg, *args, **kwargs): logging.debug(f"[{config.extension_tag}] {msg}", *args, **kwargs) +def _matches(predicate: dict): + def _filter(obj: dict): + return all(obj.get(key, None) == value for key, value in predicate.items()) + + return _filter + + +def filter_with(list: list, predicate): + if isinstance(predicate, dict): + predicate = _matches(predicate) + + return [item for item in list if predicate(item)] + + +async def get_request_body(request) -> dict: + try: + return await request.json() + except: + return {} + + def normalize_path(path: str): normpath = os.path.normpath(path) return normpath.replace(os.path.sep, "/") @@ -202,41 +223,22 @@ def get_model_preview_name(model_path: str): return images[0] if len(images) > 0 else "no-preview.png" -def save_model_preview_image(model_path: str, image_file: Any): - if not isinstance(image_file, web.FileField): - raise RuntimeError("Invalid image file") +from PIL import Image +from io import BytesIO - content_type: str = image_file.content_type - if not content_type.startswith("image/"): - raise RuntimeError(f"FileTypeError: expected image, got {content_type}") - base_dirname = os.path.dirname(model_path) +def save_model_preview_image(model_path: str, image_url: str): + try: + image_response = requests.get(image_url) + image_response.raise_for_status() - # remove old preview images - old_preview_images = get_model_all_images(model_path) - a1111_civitai_helper_image = False - for image in old_preview_images: - if os.path.splitext(image)[1].endswith(".preview"): - a1111_civitai_helper_image = True - image_path = join_path(base_dirname, image) - os.remove(image_path) + basename = os.path.splitext(model_path)[0] + preview_path = f"{basename}.webp" + image = Image.open(BytesIO(image_response.content)) + image.save(preview_path, "WEBP") - # save new preview image - basename = os.path.splitext(os.path.basename(model_path))[0] - extension = f".{content_type.split('/')[1]}" - new_preview_path = join_path(base_dirname, f"{basename}{extension}") - - with open(new_preview_path, "wb") as f: - f.write(image_file.file.read()) - - # TODO Is it possible to abandon the current rules and adopt the rules of a1111 civitai_helper? - if a1111_civitai_helper_image: - """ - Keep preview image of a1111_civitai_helper - """ - new_preview_path = join_path(base_dirname, f"{basename}.preview{extension}") - with open(new_preview_path, "wb") as f: - f.write(image_file.file.read()) + except Exception as e: + print_error(f"Failed to download image: {e}") def get_model_all_descriptions(model_path: str): @@ -361,20 +363,43 @@ def get_setting_value(request: web.Request, key: str, default: Any = None) -> An return settings.get(setting_id, default) -from dataclasses import asdict, is_dataclass - - -def unpack_dataclass(data: Any): - if isinstance(data, dict): - return {key: unpack_dataclass(value) for key, value in data.items()} - elif isinstance(data, list): - return [unpack_dataclass(x) for x in data] - elif is_dataclass(data): - return asdict(data) - else: - return data - - async def send_json(event: str, data: Any, sid: str = None): - detail = unpack_dataclass(data) - await config.serverInstance.send_json(event, detail, sid) + await config.serverInstance.send_json(event, data, sid) + + +import sys +import subprocess +import importlib.util +import importlib.metadata + + +def is_installed(package_name: str): + try: + dist = importlib.metadata.distribution(package_name) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package_name) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None + + +def pip_install(package_name: str): + subprocess.run([sys.executable, "-m", "pip", "install", package_name], check=True) + + +import hashlib + + +def calculate_sha256(path, buffer_size=1024 * 1024): + sha256 = hashlib.sha256() + with open(path, "rb") as f: + while True: + data = f.read(buffer_size) + if not data: + break + sha256.update(data) + return sha256.hexdigest() diff --git a/pyproject.toml b/pyproject.toml index 95fc4b0..3a398ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "comfyui-model-manager" description = "Manage models: browsing, download and delete." version = "2.0.3" license = "LICENSE" +dependencies = ["markdownify"] [project.urls] Repository = "https://github.com/hayden-fr/ComfyUI-Model-Manager" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..06a83f1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +markdownify \ No newline at end of file diff --git a/src/components/DialogCreateTask.vue b/src/components/DialogCreateTask.vue index e7dcef7..089449d 100644 --- a/src/components/DialogCreateTask.vue +++ b/src/components/DialogCreateTask.vue @@ -70,7 +70,6 @@ import { request } from 'hooks/request' import { useToast } from 'hooks/toast' import Button from 'primevue/button' import { VersionModel } from 'types/typings' -import { previewUrlToFile } from 'utils/common' import { ref } from 'vue' const { isMobile } = useConfig() @@ -89,38 +88,11 @@ const searchModelsByUrl = async () => { } const createDownTask = async (data: VersionModel) => { - const formData = new FormData() - loading.show() - // set base info - formData.append('type', data.type) - formData.append('pathIndex', data.pathIndex.toString()) - formData.append('fullname', data.fullname) - // set preview - const previewFile = await previewUrlToFile(data.preview as string).catch( - () => { - loading.hide() - toast.add({ - severity: 'error', - summary: 'Error', - detail: 'Failed to download preview', - life: 15000, - }) - throw new Error('Failed to download preview') - }, - ) - formData.append('previewFile', previewFile) - // set description - formData.append('description', data.description) - // set model download info - formData.append('downloadPlatform', data.downloadPlatform) - formData.append('downloadUrl', data.downloadUrl) - formData.append('sizeBytes', data.sizeBytes.toString()) - formData.append('hashes', JSON.stringify(data.hashes)) await request('/model', { method: 'POST', - body: formData, + body: JSON.stringify(data), }) .then(() => { dialog.close({ key: 'model-manager-create-task' }) diff --git a/src/hooks/config.ts b/src/hooks/config.ts index 358b148..0674d3d 100644 --- a/src/hooks/config.ts +++ b/src/hooks/config.ts @@ -1,9 +1,10 @@ -import { useRequest } from 'hooks/request' +import { request, useRequest } from 'hooks/request' import { defineStore } from 'hooks/store' -import { app } from 'scripts/comfyAPI' +import { $el, app, ComfyDialog } from 'scripts/comfyAPI' import { onMounted, onUnmounted, ref } from 'vue' +import { useToast } from './toast' -export const useConfig = defineStore('config', () => { +export const useConfig = defineStore('config', (store) => { const mobileDeviceBreakPoint = 759 const isMobile = ref(window.innerWidth < mobileDeviceBreakPoint) @@ -36,7 +37,7 @@ export const useConfig = defineStore('config', () => { refresh, } - useAddConfigSettings() + useAddConfigSettings(store) return config }) @@ -49,7 +50,41 @@ declare module 'hooks/store' { } } -function useAddConfigSettings() { +function useAddConfigSettings(store: import('hooks/store').StoreProvider) { + const { toast } = useToast() + + const confirm = (opts: { + message?: string + accept?: () => void + reject?: () => void + }) => { + const dialog = new ComfyDialog('div', []) + + dialog.show( + $el('div', [ + $el('p', { textContent: opts.message }), + $el('div.flex.gap-4', [ + $el('button.flex-1', { + textContent: 'Cancel', + onclick: () => { + opts.reject?.() + dialog.close() + document.body.removeChild(dialog.element) + }, + }), + $el('button.flex-1', { + textContent: 'Continue', + onclick: () => { + opts.accept?.() + dialog.close() + document.body.removeChild(dialog.element) + }, + }), + ]), + ]), + ) + } + onMounted(() => { // API keys app.ui?.settings.addSetting({ @@ -65,5 +100,144 @@ function useAddConfigSettings() { type: 'text', defaultValue: undefined, }) + + // Migrate + app.ui?.settings.addSetting({ + id: 'ModelManager.Migrate.Migrate', + name: 'Migrate information from cdb-boop/main', + defaultValue: '', + type: () => { + return $el('button.p-button.p-component.p-button-secondary', { + textContent: 'Migrate', + onclick: () => { + confirm({ + message: [ + 'This operation will delete old files and override current files if it exists.', + // 'This may take a while and generate MANY server requests!', + 'Continue?', + ].join('\n'), + accept: () => { + store.loading.loading.value = true + request('/migrate', { + method: 'POST', + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Complete migration', + life: 2000, + }) + store.models.refresh() + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to migrate information', + life: 15000, + }) + }) + .finally(() => { + store.loading.loading.value = false + }) + }, + }) + }, + }) + }, + }) + + // Scan information + app.ui?.settings.addSetting({ + id: 'ModelManager.ScanFiles.Full', + name: "Override all models' information and preview", + defaultValue: '', + type: () => { + return $el('button.p-button.p-component.p-button-secondary', { + textContent: 'Full Scan', + onclick: () => { + confirm({ + message: [ + 'This operation will override current files.', + 'This may take a while and generate MANY server requests!', + 'USE AT YOUR OWN RISK! Continue?', + ].join('\n'), + accept: () => { + store.loading.loading.value = true + request('/model-info/scan', { + method: 'POST', + body: JSON.stringify({ scanMode: 'full' }), + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Complete download information', + life: 2000, + }) + store.models.refresh() + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to download information', + life: 15000, + }) + }) + .finally(() => { + store.loading.loading.value = false + }) + }, + }) + }, + }) + }, + }) + + app.ui?.settings.addSetting({ + id: 'ModelManager.ScanFiles.Incremental', + name: 'Download missing information or preview', + defaultValue: '', + type: () => { + return $el('button.p-button.p-component.p-button-secondary', { + textContent: 'Diff Scan', + onclick: () => { + confirm({ + message: [ + 'Download missing information or preview.', + 'This may take a while and generate MANY server requests!', + 'USE AT YOUR OWN RISK! Continue?', + ].join('\n'), + accept: () => { + store.loading.loading.value = true + request('/model-info/scan', { + method: 'POST', + body: JSON.stringify({ scanMode: 'diff' }), + }) + .then(() => { + toast.add({ + severity: 'success', + summary: 'Complete download information', + life: 2000, + }) + store.models.refresh() + }) + .catch((err) => { + toast.add({ + severity: 'error', + summary: 'Error', + detail: err.message ?? 'Failed to download information', + life: 15000, + }) + }) + .finally(() => { + store.loading.loading.value = false + }) + }, + }) + }, + }) + }, + }) }) } diff --git a/src/hooks/download.ts b/src/hooks/download.ts index b4d3c1a..9945540 100644 --- a/src/hooks/download.ts +++ b/src/hooks/download.ts @@ -1,5 +1,4 @@ import { useLoading } from 'hooks/loading' -import { MarkdownTool, useMarkdown } from 'hooks/markdown' import { request } from 'hooks/request' import { defineStore } from 'hooks/store' import { useToast } from 'hooks/toast' @@ -157,253 +156,8 @@ declare module 'hooks/store' { } } -abstract class ModelSearch { - constructor(readonly md: MarkdownTool) {} - - abstract search(pathname: string): Promise