Files
SparkyUI/model-manager/app/registries.py
T
TBNilles e8115e7aa6 feat(model-manager): add CivitAI browse view + fix civitai.red / 401
- New "Browse CivitAI" view: thumbnail grid with search, type/sort/period
  filters and NSFW toggle; click a model card to download it (per-card
  version picker for multi-version models). Cursor + page based "Load more".
- Backend: /api/civitai/search and /api/civitai/download endpoints; new
  civitai_search() catalog helper.
- Fix 401 on paste: recognize the civitai.red mirror (and any civitai.*
  host), normalize API calls to civitai.com, and always resolve the
  model-version so type + filename are auto-detected for every CivitAI URL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-07 14:50:10 -04:00

244 lines
7.7 KiB
Python

"""Resolve a user-supplied URL into a concrete download (direct URL + auth + metadata)
and provide CivitAI catalog search.
Supports direct URLs, CivitAI (model/version pages, civitai.red mirror, or api
download links), and HuggingFace `resolve` URLs. API keys are read from settings.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Optional
from urllib.parse import unquote, urlparse, parse_qs
import httpx
from . import db
CIVITAI_API = "https://civitai.com/api/v1"
# Translation table: CivitAI model `type` string -> this app's model-type key.
# Lookups in this module use `.get(..., "other")`, so unlisted types also land
# in the "other" bucket.
CIVITAI_TYPE_MAP = dict(
    Checkpoint="checkpoint",
    # The LoRA family (plain, LoCon, DoRA) shares a single bucket.
    LORA="lora",
    LoCon="lora",
    DoRA="lora",
    TextualInversion="embedding",
    Hypernetwork="hypernetwork",
    VAE="vae",
    Controlnet="controlnet",
    Upscaler="upscaler",
    MotionModule="other",
    Poses="other",
    Wildcards="other",
    Workflows="other",
    Other="other",
)
@dataclass
class Resolved:
    """A user URL resolved into a concrete download target.

    Produced by `resolve()`; `source` is one of "direct", "huggingface" or
    "civitai" depending on which resolver built it.
    """
    download_url: str
    source: str
    # Request headers for the download. These may hold an
    # `Authorization: Bearer <key>` secret, so they are excluded from repr()
    # to keep API tokens out of logs and tracebacks (still part of __eq__).
    headers: dict[str, str] = field(default_factory=dict, repr=False)
    filename: Optional[str] = None
    model_type: Optional[str] = None  # suggested type if the registry tells us
def _is_civitai_host(host: str) -> bool:
# Matches civitai.com, civitai.red, and any civitai.* mirror.
return "civitai" in host
def detect_source(url: str) -> str:
host = (urlparse(url).hostname or "").lower()
if _is_civitai_host(host):
return "civitai"
if "huggingface.co" in host or "hf.co" in host:
return "huggingface"
return "direct"
def _filename_from_url(url: str) -> Optional[str]:
path = urlparse(url).path
name = unquote(path.rsplit("/", 1)[-1]) if path else ""
return name or None
def civitai_headers() -> dict[str, str]:
    """Build CivitAI request headers from the stored `civitai_api_key` setting.

    Returns a Bearer `Authorization` header when a key is configured, else an
    empty dict.
    """
    api_key = db.get_setting("civitai_api_key")
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}
async def resolve(url: str) -> Resolved:
    """Turn a user-supplied URL into a `Resolved` download.

    Dispatches on `detect_source`. CivitAI URLs go through the async resolver,
    which queries the CivitAI API. HuggingFace and direct URLs use synchronous
    resolvers that make no HTTP requests here.
    """
    kind = detect_source(url)
    if kind == "civitai":
        return await _resolve_civitai(url)
    sync_resolver = _resolve_huggingface if kind == "huggingface" else _resolve_direct
    return sync_resolver(url)
def _resolve_direct(url: str) -> Resolved:
    """Plain URL: download it as-is with no auth headers; filename from the path."""
    guessed_name = _filename_from_url(url)
    return Resolved(source="direct", download_url=url, filename=guessed_name)
def _resolve_huggingface(url: str) -> Resolved:
    """HuggingFace URL: download it as-is, with a Bearer token if one is configured.

    The token comes from the `huggingface_token` setting. The filename is taken
    from the last path segment of the URL.
    """
    hf_token = db.get_setting("huggingface_token")
    auth = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    return Resolved(
        download_url=url,
        source="huggingface",
        headers=auth,
        filename=_filename_from_url(url),
    )
_CIVITAI_VERSION_RE = re.compile(r"/api/download/models/(\d+)")
_CIVITAI_MODEL_RE = re.compile(r"/models/(\d+)")
async def _civitai_get(path: str) -> dict:
    """GET ``CIVITAI_API + path`` with the CivitAI auth headers and return the JSON body.

    Raises httpx.HTTPStatusError for non-2xx responses.
    """
    endpoint = f"{CIVITAI_API}{path}"
    async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as http:
        response = await http.get(endpoint)
        response.raise_for_status()
        return response.json()
async def _civitai_version_id(url: str) -> str:
    """Extract or look up the model-version id from any CivitAI URL form.

    Forms are tried in order:
      * a version id in the path (``_CIVITAI_VERSION_RE``, e.g.
        ``/api/download/models/{versionId}``);
      * a ``?modelVersionId=...`` query parameter;
      * a model page ``/models/{modelId}``, looked up via the API and mapped
        to the first entry of its ``modelVersions``.

    Raises:
        ValueError: the URL matches none of the forms, carries a non-numeric
            ``modelVersionId``, or names a model without versions.
        httpx.HTTPStatusError: the model lookup request failed.
    """
    parsed = urlparse(url)
    # api/download/models/{versionId} (and other path forms the regex covers)
    m = _CIVITAI_VERSION_RE.search(parsed.path)
    if m:
        return m.group(1)
    # ...?modelVersionId=...
    qs = parse_qs(parsed.query)
    if "modelVersionId" in qs:
        raw = qs["modelVersionId"][0].strip()
        # The id is interpolated into an authenticated API path
        # (/model-versions/{id}), so accept only plain ASCII digits.
        if not (raw.isascii() and raw.isdigit()):
            raise ValueError(f"Invalid CivitAI modelVersionId: {raw!r}")
        return raw
    # model page: /models/{id} -> first listed version
    m = _CIVITAI_MODEL_RE.search(parsed.path)
    if m:
        data = await _civitai_get(f"/models/{m.group(1)}")
        versions = data.get("modelVersions") or []
        if not versions:
            raise ValueError("CivitAI model has no downloadable versions")
        return str(versions[0]["id"])
    raise ValueError("Unrecognized CivitAI URL")
def _select_civitai_file(files: list[dict], url: str) -> Optional[dict]:
    """Choose which file of a model version *url* refers to.

    Per the CivitAI API docs, download links may narrow the file with
    ``type``, ``format``, ``size`` and ``fp`` query parameters (e.g.
    ``/api/download/models/123?type=VAE``). When any are present, the first
    file whose ``type`` / ``metadata`` values match all of them
    (case-insensitively) is returned. Otherwise, or when nothing matches, the
    file flagged ``primary`` wins, then the first file. Returns None for an
    empty list.
    """
    if not files:
        return None
    query = parse_qs(urlparse(url).query)
    wanted = {
        key: query[key][0].strip().lower()
        for key in ("type", "format", "size", "fp")
        if key in query
    }
    if wanted:
        for f in files:
            meta = f.get("metadata") or {}
            actual = {
                "type": f.get("type"),
                "format": meta.get("format"),
                "size": meta.get("size"),
                "fp": meta.get("fp"),
            }
            if all(str(actual[k] or "").lower() == v for k, v in wanted.items()):
                return f
    # Default (and previous) behavior: primary file, else the first one.
    return next((f for f in files if f.get("primary")), files[0])


async def _resolve_civitai(url: str) -> Resolved:
    """Resolve any CivitAI URL to a concrete file via the model-versions API.

    Always looks the version up, so the real filename and model type are known
    (works for civitai.com and the civitai.red mirror alike). File-selection
    query parameters on the pasted URL (type/format/size/fp) are honored.

    Raises:
        ValueError: the URL cannot be mapped to a version, or the version has
            no files.
        httpx.HTTPStatusError: a CivitAI API request failed.
    """
    headers = civitai_headers()
    version_id = await _civitai_version_id(url)
    version = await _civitai_get(f"/model-versions/{version_id}")
    chosen = _select_civitai_file(version.get("files") or [], url)
    if not chosen:
        raise ValueError("CivitAI version has no files")
    civ_type = (version.get("model") or {}).get("type") or "Other"
    model_type = CIVITAI_TYPE_MAP.get(civ_type, "other")
    # If the selected file is itself a VAE (e.g. a link with ?type=VAE on a
    # checkpoint version), suggest the VAE type, not the version's model type.
    if chosen.get("type") == "VAE":
        model_type = "vae"
    download_url = chosen.get("downloadUrl") or \
        f"https://civitai.com/api/download/models/{version_id}"
    return Resolved(
        download_url=download_url,
        source="civitai",
        headers=headers,
        filename=chosen.get("name") or _filename_from_url(download_url),
        model_type=model_type,
    )
# ---- catalog search -------------------------------------------------------
def _pick_thumbnail(version: dict) -> Optional[dict]:
"""Return {url, type} of a representative preview for a model version."""
images = version.get("images") or []
for im in images:
if im.get("type") == "image":
return {"url": im.get("url"), "type": "image"}
if images:
return {"url": images[0].get("url"), "type": images[0].get("type", "image")}
return None
def _to_card(item: dict) -> dict:
    """Flatten one CivitAI `/models` search item into a UI card dict.

    The thumbnail and primary version come from the first entry of
    `modelVersions`. Every version is listed as ``{id, name}`` for the
    per-card version picker.
    """
    versions = item.get("modelVersions") or []
    first = versions[0] if versions else {}
    preview = (_pick_thumbnail(first) if first else None) or {}
    stats = item.get("stats") or {}
    civ_type = item.get("type")
    card = {
        "id": item.get("id"),
        "name": item.get("name"),
        "type": civ_type,
        "type_key": CIVITAI_TYPE_MAP.get(civ_type or "", "other"),
        "nsfw": item.get("nsfw", False),
        "creator": (item.get("creator") or {}).get("username"),
        "downloads": stats.get("downloadCount"),
        "thumbnail": preview.get("url"),
        "thumbnail_type": preview.get("type"),
        "primary_version_id": first.get("id"),
        "versions": [dict(id=v.get("id"), name=v.get("name")) for v in versions],
    }
    return card
async def civitai_search(query: Optional[str] = None,
                         types: Optional[list[str]] = None,
                         sort: Optional[str] = None,
                         period: Optional[str] = None,
                         nsfw: Optional[bool] = None,
                         page: Optional[int] = None,
                         cursor: Optional[str] = None,
                         limit: int = 24) -> dict:
    """Search the CivitAI model catalog and return normalized cards plus paging info.

    CivitAI uses cursor paging whenever a `query` is supplied (page is
    ignored) and page paging otherwise, so both are accepted and both are
    reported back. `limit` is clamped to 1..100.

    Raises httpx.HTTPStatusError for non-2xx responses.
    """
    params: dict = {"limit": max(1, min(limit, 100))}
    # Only send filters that are set. httpx expands the `types` list into
    # repeated query parameters.
    filters = {"query": query, "types": types, "sort": sort, "period": period}
    params.update((name, value) for name, value in filters.items() if value)
    if nsfw is not None:
        params["nsfw"] = "true" if nsfw else "false"
    # A cursor takes precedence over a page number.
    if cursor:
        params["cursor"] = cursor
    elif page:
        params["page"] = page

    async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as http:
        response = await http.get(f"{CIVITAI_API}/models", params=params)
        response.raise_for_status()
        payload = response.json()

    cards = [_to_card(raw) for raw in (payload.get("items") or [])]
    paging = payload.get("metadata") or {}
    result = {"items": cards}
    for key in ("nextPage", "nextCursor", "currentPage", "totalPages"):
        result[key] = paging.get(key)
    return result