Feature scan info (#53)

* pref: migrate fetch model info to end back

* fix(download): can't select model type

* feat: add scan model info

* feat: add trigger button in setting

* feat: add printing logs

* chore: add explanation of scan model info
This commit is contained in:
Hayden
2024-11-21 22:04:39 +08:00
committed by GitHub
parent 6ae7e1835f
commit 659637c6e0
20 changed files with 921 additions and 428 deletions

View File

@@ -24,6 +24,34 @@ class TaskStatus:
bps: float = 0
error: Optional[str] = None
def __init__(self, **kwargs):
self.taskId = kwargs.get("taskId", None)
self.type = kwargs.get("type", None)
self.fullname = kwargs.get("fullname", None)
self.preview = kwargs.get("preview", None)
self.status = kwargs.get("status", "pause")
self.platform = kwargs.get("platform", None)
self.downloadedSize = kwargs.get("downloadedSize", 0)
self.totalSize = kwargs.get("totalSize", 0)
self.progress = kwargs.get("progress", 0)
self.bps = kwargs.get("bps", 0)
self.error = kwargs.get("error", None)
def to_dict(self):
return {
"taskId": self.taskId,
"type": self.type,
"fullname": self.fullname,
"preview": self.preview,
"status": self.status,
"platform": self.platform,
"downloadedSize": self.downloadedSize,
"totalSize": self.totalSize,
"progress": self.progress,
"bps": self.bps,
"error": self.error,
}
@dataclass
class TaskContent:
@@ -33,9 +61,31 @@ class TaskContent:
description: str
downloadPlatform: str
downloadUrl: str
sizeBytes: float
sizeBytes: int
hashes: Optional[dict[str, str]] = None
def __init__(self, **kwargs):
self.type = kwargs.get("type", None)
self.pathIndex = int(kwargs.get("pathIndex", 0))
self.fullname = kwargs.get("fullname", None)
self.description = kwargs.get("description", None)
self.downloadPlatform = kwargs.get("downloadPlatform", None)
self.downloadUrl = kwargs.get("downloadUrl", None)
self.sizeBytes = int(kwargs.get("sizeBytes", 0))
self.hashes = kwargs.get("hashes", None)
def to_dict(self):
return {
"type": self.type,
"pathIndex": self.pathIndex,
"fullname": self.fullname,
"description": self.description,
"downloadPlatform": self.downloadPlatform,
"downloadUrl": self.downloadUrl,
"sizeBytes": self.sizeBytes,
"hashes": self.hashes,
}
download_model_task_status: dict[str, TaskStatus] = {}
download_thread_pool = thread.DownloadThreadPool()
@@ -44,7 +94,7 @@ download_thread_pool = thread.DownloadThreadPool()
def set_task_content(task_id: str, task_content: Union[TaskContent, dict]):
download_path = utils.get_download_path()
task_file_path = utils.join_path(download_path, f"{task_id}.task")
utils.save_dict_pickle_file(task_file_path, utils.unpack_dataclass(task_content))
utils.save_dict_pickle_file(task_file_path, task_content)
def get_task_content(task_id: str):
@@ -53,8 +103,6 @@ def get_task_content(task_id: str):
if not os.path.isfile(task_file):
raise RuntimeError(f"Task {task_id} not found")
task_content = utils.load_dict_pickle_file(task_file)
task_content["pathIndex"] = int(task_content.get("pathIndex", 0))
task_content["sizeBytes"] = float(task_content.get("sizeBytes", 0))
return TaskContent(**task_content)
@@ -106,14 +154,14 @@ async def scan_model_download_task_list():
for task_file in task_files:
task_id = task_file.replace(".task", "")
task_status = get_task_status(task_id)
task_list.append(task_status)
task_list.append(task_status.to_dict())
return utils.unpack_dataclass(task_list)
return task_list
async def create_model_download_task(task_data: dict, request):
"""
Creates a download task for the given post.
Creates a download task for the given data.
"""
model_type = task_data.get("type", None)
path_index = int(task_data.get("pathIndex", None))
@@ -132,8 +180,8 @@ async def create_model_download_task(task_data: dict, request):
raise RuntimeError(f"Task {task_id} already exists")
try:
previewFile = task_data.pop("previewFile", None)
utils.save_model_preview_image(task_path, previewFile)
preview_url = task_data.pop("preview", None)
utils.save_model_preview_image(task_path, preview_url)
set_task_content(task_id, task_data)
task_status = TaskStatus(
taskId=task_id,
@@ -144,7 +192,7 @@ async def create_model_download_task(task_data: dict, request):
totalSize=float(task_data.get("sizeBytes", 0)),
)
download_model_task_status[task_id] = task_status
await utils.send_json("create_download_task", task_status)
await utils.send_json("create_download_task", task_status.to_dict())
except Exception as e:
await delete_model_download_task(task_id)
raise RuntimeError(str(e)) from e
@@ -183,7 +231,7 @@ async def delete_model_download_task(task_id: str):
async def download_model(task_id: str, request):
async def download_task(task_id: str):
async def report_progress(task_status: TaskStatus):
await utils.send_json("update_download_task", task_status)
await utils.send_json("update_download_task", task_status.to_dict())
try:
# When starting a task from the queue, the task may not exist
@@ -193,7 +241,7 @@ async def download_model(task_id: str, request):
# Update task status
task_status.status = "doing"
await utils.send_json("update_download_task", task_status)
await utils.send_json("update_download_task", task_status.to_dict())
try:
@@ -221,7 +269,7 @@ async def download_model(task_id: str, request):
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
await utils.send_json("update_download_task", task_status)
await utils.send_json("update_download_task", task_status.to_dict())
task_status.error = None
utils.print_error(str(e))
@@ -230,11 +278,11 @@ async def download_model(task_id: str, request):
if status == "Waiting":
task_status = get_task_status(task_id)
task_status.status = "waiting"
await utils.send_json("update_download_task", task_status)
await utils.send_json("update_download_task", task_status.to_dict())
except Exception as e:
task_status.status = "pause"
task_status.error = str(e)
await utils.send_json("update_download_task", task_status)
await utils.send_json("update_download_task", task_status.to_dict())
task_status.error = None
utils.print_error(str(e))
@@ -339,7 +387,7 @@ async def download_model_file(
task_content.sizeBytes = total_size
task_status.totalSize = total_size
set_task_content(task_id, task_content)
await utils.send_json("update_download_task", task_content)
await utils.send_json("update_download_task", task_content.to_dict())
with open(download_tmp_file, "ab") as f:
for chunk in response.iter_content(chunk_size=8192):
@@ -358,4 +406,4 @@ async def download_model_file(
await download_complete()
else:
task_status.status = "pause"
await utils.send_json("update_download_task", task_status)
await utils.send_json("update_download_task", task_status.to_dict())

317
py/searcher.py Normal file
View File

@@ -0,0 +1,317 @@
import os
import re
import yaml
import requests
import markdownify
from abc import ABC, abstractmethod
from urllib.parse import urlparse, parse_qs
from . import utils
class ModelSearcher(ABC):
"""
Abstract class for model searcher.
"""
@abstractmethod
def search_by_url(self, url: str) -> list[dict]:
pass
@abstractmethod
def search_by_hash(self, hash: str) -> dict:
pass
class UnknownWebsiteSearcher(ModelSearcher):
def search_by_url(self, url: str):
raise RuntimeError(
f"Unknown Website, please input a URL from huggingface.co or civitai.com."
)
def search_by_hash(self, hash: str):
raise RuntimeError(f"Unknown Website, unable to search with hash value.")
class CivitaiModelSearcher(ModelSearcher):
def search_by_url(self, url: str):
parsed_url = urlparse(url)
pathname = parsed_url.path
match = re.match(r"^/models/(\d*)", pathname)
model_id = match.group(1) if match else None
query_params = parse_qs(parsed_url.query)
version_id = query_params.get("modelVersionId", [None])[0]
if not model_id:
return []
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}")
response.raise_for_status()
res_data: dict = response.json()
model_versions: list[dict] = res_data["modelVersions"]
if version_id:
model_versions = utils.filter_with(model_versions, {"id": int(version_id)})
models: list[dict] = []
for version in model_versions:
model_files: list[dict] = version.get("files", [])
model_files = utils.filter_with(model_files, {"type": "Model"})
shortname = version.get("name", None) if len(model_files) > 0 else None
for file in model_files:
fullname = file.get("name", None)
extension = os.path.splitext(fullname)[1]
basename = os.path.splitext(fullname)[0]
metadata_info = {
"website": "Civitai",
"modelPage": f"https://civitai.com/models/{model_id}?modelVersionId={version.get('id')}",
"author": res_data.get("creator", {}).get("username", None),
"baseModel": version.get("baseModel"),
"hashes": file.get("hashes"),
"metadata": file.get("metadata"),
"preview": [i["url"] for i in version["images"]],
}
description_parts: list[str] = []
description_parts.append("---")
description_parts.append(yaml.dump(metadata_info).strip())
description_parts.append("---")
description_parts.append("")
description_parts.append(f"# Trigger Words")
description_parts.append("")
description_parts.append(
", ".join(version.get("trainedWords", ["No trigger words"]))
)
description_parts.append("")
description_parts.append(f"# About this version")
description_parts.append("")
description_parts.append(
markdownify.markdownify(
version.get(
"description", "<p>No description about this version</p>"
)
).strip()
)
description_parts.append("")
description_parts.append(f"# {res_data.get('name')}")
description_parts.append("")
description_parts.append(
markdownify.markdownify(
res_data.get(
"description", "<p>No description about this model</p>"
)
).strip()
)
description_parts.append("")
model = {
"id": file.get("id"),
"shortname": shortname or basename,
"fullname": fullname,
"basename": basename,
"extension": extension,
"preview": metadata_info.get("preview"),
"sizeBytes": file.get("sizeKB", 0) * 1024,
"type": self._resolve_model_type(res_data.get("type", "unknown")),
"pathIndex": 0,
"description": "\n".join(description_parts),
"metadata": file.get("metadata"),
"downloadPlatform": "civitai",
"downloadUrl": file.get("downloadUrl"),
"hashes": file.get("hashes"),
}
models.append(model)
return models
def search_by_hash(self, hash: str):
if not hash:
raise RuntimeError(f"Hash value is empty.")
response = requests.get(
f"https://civitai.com/api/v1/model-versions/by-hash/{hash}"
)
response.raise_for_status()
version: dict = response.json()
model_id = version.get("modelId")
version_id = version.get("id")
model_page = (
f"https://civitai.com/models/{model_id}?modelVersionId={version_id}"
)
models = self.search_by_url(model_page)
for model in models:
sha256 = model.get("hashes", {}).get("SHA256")
if sha256 == hash:
return model
return models[0]
def _resolve_model_type(self, model_type: str):
map_legacy = {
"TextualInversion": "embeddings",
"LoCon": "loras",
"DoRA": "loras",
"Controlnet": "controlnet",
"Upscaler": "upscale_models",
"VAE": "vae",
"unknown": "unknown",
}
return map_legacy.get(model_type, f"{model_type.lower()}s")
class HuggingfaceModelSearcher(ModelSearcher):
def search_by_url(self, url: str):
parsed_url = urlparse(url)
pathname = parsed_url.path
space, name, *rest_paths = pathname.strip("/").split("/")
model_id = f"{space}/{name}"
rest_pathname = "/".join(rest_paths)
response = requests.get(f"https://huggingface.co/api/models/{model_id}")
response.raise_for_status()
res_data: dict = response.json()
sibling_files: list[str] = [
x.get("rfilename") for x in res_data.get("siblings", [])
]
model_files = utils.filter_with(
utils.filter_with(sibling_files, self._match_model_files()),
self._match_tree_files(rest_pathname),
)
image_files = utils.filter_with(
utils.filter_with(sibling_files, self._match_image_files()),
self._match_tree_files(rest_pathname),
)
image_files = [
f"https://huggingface.co/{model_id}/resolve/main/{filename}"
for filename in image_files
]
models: list[dict] = []
for filename in model_files:
fullname = os.path.basename(filename)
extension = os.path.splitext(fullname)[1]
basename = os.path.splitext(fullname)[0]
description_parts: list[str] = []
metadata_info = {
"website": "HuggingFace",
"modelPage": f"https://huggingface.co/{model_id}",
"author": res_data.get("author", None),
"preview": image_files,
}
description_parts: list[str] = []
description_parts.append("---")
description_parts.append(yaml.dump(metadata_info).strip())
description_parts.append("---")
description_parts.append("")
description_parts.append(f"# Trigger Words")
description_parts.append("")
description_parts.append("No trigger words")
description_parts.append("")
description_parts.append(f"# About this version")
description_parts.append("")
description_parts.append("No description about this version")
description_parts.append("")
description_parts.append(f"# {res_data.get('name')}")
description_parts.append("")
description_parts.append("No description about this model")
description_parts.append("")
model = {
"id": filename,
"shortname": filename,
"fullname": fullname,
"basename": basename,
"extension": extension,
"preview": image_files,
"sizeBytes": 0,
"type": "unknown",
"pathIndex": 0,
"description": "\n".join(description_parts),
"metadata": {},
"downloadPlatform": "",
"downloadUrl": f"https://huggingface.co/{model_id}/resolve/main/{filename}?download=true",
}
models.append(model)
return models
def search_by_hash(self, hash: str):
raise RuntimeError("Hash search is not supported by Huggingface.")
def _match_model_files(self):
extension = [
".bin",
".ckpt",
".gguf",
".onnx",
".pt",
".pth",
".safetensors",
]
def _filter_model_files(file: str):
return any(file.endswith(ext) for ext in extension)
return _filter_model_files
def _match_image_files(self):
extension = [
".png",
".webp",
".jpeg",
".jpg",
".jfif",
".gif",
".apng",
]
def _filter_image_files(file: str):
return any(file.endswith(ext) for ext in extension)
return _filter_image_files
def _match_tree_files(self, pathname: str):
target, *paths = pathname.split("/")
def _filter_tree_files(file: str):
if not target:
return True
if target != "tree" and target != "blob":
return True
prefix_path = "/".join(paths)
return file.startswith(prefix_path)
return _filter_tree_files
def get_model_searcher_by_url(url: str) -> ModelSearcher:
parsed_url = urlparse(url)
host_name = parsed_url.hostname
if host_name == "civitai.com":
return CivitaiModelSearcher()
elif host_name == "huggingface.co":
return HuggingfaceModelSearcher()
return UnknownWebsiteSearcher()

View File

@@ -5,6 +5,7 @@ import folder_paths
from . import config
from . import utils
from . import download
from . import searcher
def scan_models():
@@ -128,3 +129,180 @@ async def resume_model_download_task(task_id, request):
async def delete_model_download_task(task_id):
return await download.delete_model_download_task(task_id)
def fetch_model_info(model_page: str):
if not model_page:
return []
model_searcher = searcher.get_model_searcher_by_url(model_page)
result = model_searcher.search_by_url(model_page)
return result
async def download_model_info(scan_mode: str):
utils.print_info(f"Download model info for {scan_mode}")
model_base_paths = config.model_base_paths
for model_type in model_base_paths:
folders, extensions = folder_paths.folder_names_and_paths[model_type]
for path_index, base_path in enumerate(folders):
files = utils.recursive_search_files(base_path)
models = folder_paths.filter_files_extensions(files, extensions)
images = folder_paths.filter_files_content_types(files, ["image"])
image_dict = utils.file_list_to_name_dict(images)
descriptions = folder_paths.filter_files_extensions(files, [".md"])
description_dict = utils.file_list_to_name_dict(descriptions)
for fullname in models:
fullname = utils.normalize_path(fullname)
basename = os.path.splitext(fullname)[0]
abs_model_path = utils.join_path(base_path, fullname)
image_name = image_dict.get(basename, "no-preview.png")
abs_image_path = utils.join_path(base_path, image_name)
has_preview = os.path.isfile(abs_image_path)
description_name = description_dict.get(basename, None)
abs_description_path = (
utils.join_path(base_path, description_name)
if description_name
else None
)
has_description = (
os.path.isfile(abs_description_path)
if abs_description_path
else False
)
try:
utils.print_info(f"Checking model {abs_model_path}")
utils.print_debug(f"Scan mode: {scan_mode}")
utils.print_debug(f"Has preview: {has_preview}")
utils.print_debug(f"Has description: {has_description}")
if scan_mode != "full" and (has_preview and has_description):
continue
utils.print_debug(f"Calculate sha256 for {abs_model_path}")
hash_value = utils.calculate_sha256(abs_model_path)
utils.print_info(f"Searching model info by hash {hash_value}")
model_info = searcher.CivitaiModelSearcher().search_by_hash(
hash_value
)
preview_url_list = model_info.get("preview", [])
preview_image_url = (
preview_url_list[0] if preview_url_list else None
)
if preview_image_url:
utils.print_debug(f"Save preview image to {abs_image_path}")
utils.save_model_preview_image(
abs_model_path, preview_image_url
)
description = model_info.get("description", None)
if description:
utils.save_model_description(abs_model_path, description)
except Exception as e:
utils.print_error(
f"Failed to download model info for {abs_model_path}: {e}"
)
utils.print_debug("Completed scan model information.")
async def migrate_legacy_information():
import json
import yaml
from PIL import Image
utils.print_info(f"Migrating legacy information...")
model_base_paths = config.model_base_paths
for model_type in model_base_paths:
folders, extensions = folder_paths.folder_names_and_paths[model_type]
for path_index, base_path in enumerate(folders):
files = utils.recursive_search_files(base_path)
models = folder_paths.filter_files_extensions(files, extensions)
for fullname in models:
fullname = utils.normalize_path(fullname)
abs_model_path = utils.join_path(base_path, fullname)
base_file_name = os.path.splitext(abs_model_path)[0]
utils.print_debug(f"Try to migrate legacy info for {abs_model_path}")
preview_path = utils.join_path(
os.path.dirname(abs_model_path),
utils.get_model_preview_name(abs_model_path),
)
new_preview_path = f"{base_file_name}.webp"
if os.path.isfile(preview_path) and preview_path != new_preview_path:
utils.print_info(f"Migrate preview image from {fullname}")
with Image.open(preview_path) as image:
image.save(new_preview_path, format="WEBP")
os.remove(preview_path)
description_path = f"{base_file_name}.md"
metadata_info = {
"website": "Civitai",
}
url_info_path = f"{base_file_name}.url"
if os.path.isfile(url_info_path):
with open(url_info_path, "r", encoding="utf-8") as f:
for line in f:
if line.startswith("URL="):
model_page_url = line[len("URL=") :].strip()
metadata_info.update({"modelPage": model_page_url})
json_info_path = f"{base_file_name}.json"
if os.path.isfile(json_info_path):
with open(json_info_path, "r", encoding="utf-8") as f:
version = json.load(f)
metadata_info.update(
{
"baseModel": version.get("baseModel"),
"preview": [i["url"] for i in version["images"]],
}
)
description_parts: list[str] = [
"---",
yaml.dump(metadata_info).strip(),
"---",
"",
]
text_info_path = f"{base_file_name}.txt"
if os.path.isfile(text_info_path):
with open(text_info_path, "r", encoding="utf-8") as f:
description_parts.append(f.read())
description_path = f"{base_file_name}.md"
if os.path.isfile(text_info_path):
utils.print_info(f"Migrate description from {fullname}")
with open(description_path, "w", encoding="utf-8", newline="") as f:
f.write("\n".join(description_parts))
def try_to_remove_file(file_path):
if os.path.isfile(file_path):
os.remove(file_path)
try_to_remove_file(url_info_path)
try_to_remove_file(text_info_path)
try_to_remove_file(json_info_path)
utils.print_debug("Completed migrate model information.")

View File

@@ -29,6 +29,27 @@ def print_debug(msg, *args, **kwargs):
logging.debug(f"[{config.extension_tag}] {msg}", *args, **kwargs)
def _matches(predicate: dict):
def _filter(obj: dict):
return all(obj.get(key, None) == value for key, value in predicate.items())
return _filter
def filter_with(list: list, predicate):
if isinstance(predicate, dict):
predicate = _matches(predicate)
return [item for item in list if predicate(item)]
async def get_request_body(request) -> dict:
try:
return await request.json()
except:
return {}
def normalize_path(path: str):
normpath = os.path.normpath(path)
return normpath.replace(os.path.sep, "/")
@@ -202,41 +223,22 @@ def get_model_preview_name(model_path: str):
return images[0] if len(images) > 0 else "no-preview.png"
def save_model_preview_image(model_path: str, image_file: Any):
if not isinstance(image_file, web.FileField):
raise RuntimeError("Invalid image file")
from PIL import Image
from io import BytesIO
content_type: str = image_file.content_type
if not content_type.startswith("image/"):
raise RuntimeError(f"FileTypeError: expected image, got {content_type}")
base_dirname = os.path.dirname(model_path)
def save_model_preview_image(model_path: str, image_url: str):
try:
image_response = requests.get(image_url)
image_response.raise_for_status()
# remove old preview images
old_preview_images = get_model_all_images(model_path)
a1111_civitai_helper_image = False
for image in old_preview_images:
if os.path.splitext(image)[1].endswith(".preview"):
a1111_civitai_helper_image = True
image_path = join_path(base_dirname, image)
os.remove(image_path)
basename = os.path.splitext(model_path)[0]
preview_path = f"{basename}.webp"
image = Image.open(BytesIO(image_response.content))
image.save(preview_path, "WEBP")
# save new preview image
basename = os.path.splitext(os.path.basename(model_path))[0]
extension = f".{content_type.split('/')[1]}"
new_preview_path = join_path(base_dirname, f"{basename}{extension}")
with open(new_preview_path, "wb") as f:
f.write(image_file.file.read())
# TODO Is it possible to abandon the current rules and adopt the rules of a1111 civitai_helper?
if a1111_civitai_helper_image:
"""
Keep preview image of a1111_civitai_helper
"""
new_preview_path = join_path(base_dirname, f"{basename}.preview{extension}")
with open(new_preview_path, "wb") as f:
f.write(image_file.file.read())
except Exception as e:
print_error(f"Failed to download image: {e}")
def get_model_all_descriptions(model_path: str):
@@ -361,20 +363,43 @@ def get_setting_value(request: web.Request, key: str, default: Any = None) -> An
return settings.get(setting_id, default)
from dataclasses import asdict, is_dataclass
def unpack_dataclass(data: Any):
if isinstance(data, dict):
return {key: unpack_dataclass(value) for key, value in data.items()}
elif isinstance(data, list):
return [unpack_dataclass(x) for x in data]
elif is_dataclass(data):
return asdict(data)
else:
return data
async def send_json(event: str, data: Any, sid: str = None):
detail = unpack_dataclass(data)
await config.serverInstance.send_json(event, detail, sid)
await config.serverInstance.send_json(event, data, sid)
import sys
import subprocess
import importlib.util
import importlib.metadata
def is_installed(package_name: str):
try:
dist = importlib.metadata.distribution(package_name)
except importlib.metadata.PackageNotFoundError:
try:
spec = importlib.util.find_spec(package_name)
except ModuleNotFoundError:
return False
return spec is not None
return dist is not None
def pip_install(package_name: str):
subprocess.run([sys.executable, "-m", "pip", "install", package_name], check=True)
import hashlib
def calculate_sha256(path, buffer_size=1024 * 1024):
sha256 = hashlib.sha256()
with open(path, "rb") as f:
while True:
data = f.read(buffer_size)
if not data:
break
sha256.update(data)
return sha256.hexdigest()