feat(model-manager): add CivitAI browse view + fix civitai.red / 401
- New "Browse CivitAI" view: thumbnail grid with search, type/sort/period filters and NSFW toggle; click a model card to download it (per-card version picker for multi-version models). Cursor + page based "Load more". - Backend: /api/civitai/search and /api/civitai/download endpoints; new civitai_search() catalog helper. - Fix 401 on paste: recognize the civitai.red mirror (and any civitai.* host), normalize API calls to civitai.com, and always resolve the model-version so type + filename are auto-detected for every CivitAI URL. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+140
-39
@@ -1,7 +1,8 @@
|
||||
"""Resolve a user-supplied URL into a concrete download (direct URL + auth + metadata).
|
||||
"""Resolve a user-supplied URL into a concrete download (direct URL + auth + metadata)
|
||||
and provide CivitAI catalog search.
|
||||
|
||||
Supports direct URLs, CivitAI (model/version pages or api download links), and
|
||||
HuggingFace `resolve` URLs. API keys are read from the settings table.
|
||||
Supports direct URLs, CivitAI (model/version pages, civitai.red mirror, or api
|
||||
download links), and HuggingFace `resolve` URLs. API keys are read from settings.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -14,6 +15,8 @@ import httpx
|
||||
|
||||
from . import db
|
||||
|
||||
CIVITAI_API = "https://civitai.com/api/v1"
|
||||
|
||||
# CivitAI model `type` -> our model-type key.
|
||||
CIVITAI_TYPE_MAP = {
|
||||
"Checkpoint": "checkpoint",
|
||||
@@ -28,6 +31,7 @@ CIVITAI_TYPE_MAP = {
|
||||
"MotionModule": "other",
|
||||
"Poses": "other",
|
||||
"Wildcards": "other",
|
||||
"Workflows": "other",
|
||||
"Other": "other",
|
||||
}
|
||||
|
||||
@@ -41,9 +45,14 @@ class Resolved:
|
||||
model_type: Optional[str] = None # suggested type if the registry tells us
|
||||
|
||||
|
||||
def _is_civitai_host(host: str) -> bool:
|
||||
# Matches civitai.com, civitai.red, and any civitai.* mirror.
|
||||
return "civitai" in host
|
||||
|
||||
|
||||
def detect_source(url: str) -> str:
|
||||
host = (urlparse(url).hostname or "").lower()
|
||||
if "civitai.com" in host:
|
||||
if _is_civitai_host(host):
|
||||
return "civitai"
|
||||
if "huggingface.co" in host or "hf.co" in host:
|
||||
return "huggingface"
|
||||
@@ -56,6 +65,14 @@ def _filename_from_url(url: str) -> Optional[str]:
|
||||
return name or None
|
||||
|
||||
|
||||
def civitai_headers() -> dict[str, str]:
|
||||
headers: dict[str, str] = {}
|
||||
key = db.get_setting("civitai_api_key")
|
||||
if key:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
return headers
|
||||
|
||||
|
||||
async def resolve(url: str) -> Resolved:
|
||||
source = detect_source(url)
|
||||
if source == "civitai":
|
||||
@@ -83,49 +100,50 @@ _CIVITAI_VERSION_RE = re.compile(r"/api/download/models/(\d+)")
|
||||
_CIVITAI_MODEL_RE = re.compile(r"/models/(\d+)")
|
||||
|
||||
|
||||
async def _resolve_civitai(url: str) -> Resolved:
|
||||
headers: dict[str, str] = {}
|
||||
api_key = db.get_setting("civitai_api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
async def _civitai_get(path: str) -> dict:
|
||||
async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as client:
|
||||
resp = await client.get(f"{CIVITAI_API}{path}")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# Case 1: already a direct api/download URL -> use as-is.
|
||||
if _CIVITAI_VERSION_RE.search(urlparse(url).path):
|
||||
return Resolved(download_url=url, source="civitai", headers=headers,
|
||||
filename=_filename_from_url(url))
|
||||
|
||||
# Case 2: a model page URL. Find the version id (explicit query or first version).
|
||||
async def _civitai_version_id(url: str) -> str:
|
||||
"""Extract or look up the model-version id from any CivitAI URL form."""
|
||||
parsed = urlparse(url)
|
||||
qs = parse_qs(parsed.query)
|
||||
version_id: Optional[str] = None
|
||||
if "modelVersionId" in qs:
|
||||
version_id = qs["modelVersionId"][0]
|
||||
|
||||
if version_id is None:
|
||||
m = _CIVITAI_MODEL_RE.search(parsed.path)
|
||||
if not m:
|
||||
# Can't understand it; fall back to treating it as a direct link.
|
||||
return Resolved(download_url=url, source="civitai", headers=headers,
|
||||
filename=_filename_from_url(url))
|
||||
model_id = m.group(1)
|
||||
async with httpx.AsyncClient(timeout=30, headers=headers) as client:
|
||||
resp = await client.get(f"https://civitai.com/api/v1/models/{model_id}")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# api/download/models/{versionId}
|
||||
m = _CIVITAI_VERSION_RE.search(parsed.path)
|
||||
if m:
|
||||
return m.group(1)
|
||||
|
||||
# ...?modelVersionId=...
|
||||
qs = parse_qs(parsed.query)
|
||||
if "modelVersionId" in qs:
|
||||
return qs["modelVersionId"][0]
|
||||
|
||||
# model page: /models/{id} -> first version
|
||||
m = _CIVITAI_MODEL_RE.search(parsed.path)
|
||||
if m:
|
||||
data = await _civitai_get(f"/models/{m.group(1)}")
|
||||
versions = data.get("modelVersions") or []
|
||||
if not versions:
|
||||
raise ValueError("CivitAI model has no downloadable versions")
|
||||
version_id = str(versions[0]["id"])
|
||||
return str(versions[0]["id"])
|
||||
|
||||
# Resolve the version to a concrete file + type.
|
||||
async with httpx.AsyncClient(timeout=30, headers=headers) as client:
|
||||
resp = await client.get(
|
||||
f"https://civitai.com/api/v1/model-versions/{version_id}")
|
||||
resp.raise_for_status()
|
||||
version = resp.json()
|
||||
raise ValueError("Unrecognized CivitAI URL")
|
||||
|
||||
|
||||
async def _resolve_civitai(url: str) -> Resolved:
|
||||
"""Resolve any CivitAI URL to a concrete file via the model-versions API.
|
||||
|
||||
Always looks the version up so we get the real filename and model type
|
||||
(works for civitai.com and the civitai.red mirror alike).
|
||||
"""
|
||||
headers = civitai_headers()
|
||||
version_id = await _civitai_version_id(url)
|
||||
version = await _civitai_get(f"/model-versions/{version_id}")
|
||||
|
||||
files = version.get("files") or []
|
||||
# Prefer the primary file, else the first.
|
||||
chosen = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||
if not chosen:
|
||||
raise ValueError("CivitAI version has no files")
|
||||
@@ -133,10 +151,93 @@ async def _resolve_civitai(url: str) -> Resolved:
|
||||
civ_type = (version.get("model") or {}).get("type") or "Other"
|
||||
model_type = CIVITAI_TYPE_MAP.get(civ_type, "other")
|
||||
|
||||
download_url = chosen.get("downloadUrl") or \
|
||||
f"https://civitai.com/api/download/models/{version_id}"
|
||||
|
||||
return Resolved(
|
||||
download_url=chosen["downloadUrl"],
|
||||
download_url=download_url,
|
||||
source="civitai",
|
||||
headers=headers,
|
||||
filename=chosen.get("name") or _filename_from_url(chosen["downloadUrl"]),
|
||||
filename=chosen.get("name") or _filename_from_url(download_url),
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
|
||||
# ---- catalog search -------------------------------------------------------
|
||||
|
||||
def _pick_thumbnail(version: dict) -> Optional[dict]:
|
||||
"""Return {url, type} of a representative preview for a model version."""
|
||||
images = version.get("images") or []
|
||||
for im in images:
|
||||
if im.get("type") == "image":
|
||||
return {"url": im.get("url"), "type": "image"}
|
||||
if images:
|
||||
return {"url": images[0].get("url"), "type": images[0].get("type", "image")}
|
||||
return None
|
||||
|
||||
|
||||
def _to_card(item: dict) -> dict:
|
||||
versions = item.get("modelVersions") or []
|
||||
v0 = versions[0] if versions else {}
|
||||
thumb = _pick_thumbnail(v0) if v0 else None
|
||||
stats = item.get("stats") or {}
|
||||
return {
|
||||
"id": item.get("id"),
|
||||
"name": item.get("name"),
|
||||
"type": item.get("type"),
|
||||
"type_key": CIVITAI_TYPE_MAP.get(item.get("type") or "", "other"),
|
||||
"nsfw": item.get("nsfw", False),
|
||||
"creator": (item.get("creator") or {}).get("username"),
|
||||
"downloads": stats.get("downloadCount"),
|
||||
"thumbnail": thumb.get("url") if thumb else None,
|
||||
"thumbnail_type": thumb.get("type") if thumb else None,
|
||||
"primary_version_id": v0.get("id"),
|
||||
"versions": [
|
||||
{"id": v.get("id"), "name": v.get("name")} for v in versions
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def civitai_search(query: Optional[str] = None,
|
||||
types: Optional[list[str]] = None,
|
||||
sort: Optional[str] = None,
|
||||
period: Optional[str] = None,
|
||||
nsfw: Optional[bool] = None,
|
||||
page: Optional[int] = None,
|
||||
cursor: Optional[str] = None,
|
||||
limit: int = 24) -> dict:
|
||||
"""Search the CivitAI model catalog, returning normalized cards + paging.
|
||||
|
||||
CivitAI uses cursor paging whenever a `query` is supplied (page is ignored),
|
||||
and page paging otherwise, so we support and return both.
|
||||
"""
|
||||
params: dict = {"limit": max(1, min(limit, 100))}
|
||||
if query:
|
||||
params["query"] = query
|
||||
if types:
|
||||
params["types"] = types # httpx serializes lists as repeated params
|
||||
if sort:
|
||||
params["sort"] = sort
|
||||
if period:
|
||||
params["period"] = period
|
||||
if nsfw is not None:
|
||||
params["nsfw"] = "true" if nsfw else "false"
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
elif page:
|
||||
params["page"] = page
|
||||
|
||||
async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as client:
|
||||
resp = await client.get(f"{CIVITAI_API}/models", params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
items = [_to_card(it) for it in (data.get("items") or [])]
|
||||
meta = data.get("metadata") or {}
|
||||
return {
|
||||
"items": items,
|
||||
"nextPage": meta.get("nextPage"),
|
||||
"nextCursor": meta.get("nextCursor"),
|
||||
"currentPage": meta.get("currentPage"),
|
||||
"totalPages": meta.get("totalPages"),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user