505f212c4d
CivitAI Early Access versions require purchased access and otherwise fail with 401. Surface version `availability` from the API as an `early_access` flag (per card and per version), show an amber "EARLY ACCESS" tag on the card, label such entries in the version dropdown, and warn before attempting to download one. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
256 lines
8.1 KiB
Python
256 lines
8.1 KiB
Python
"""Resolve a user-supplied URL into a concrete download (direct URL + auth + metadata)
|
|
and provide CivitAI catalog search.
|
|
|
|
Supports direct URLs, CivitAI (model/version pages, civitai.red mirror, or api
|
|
download links), and HuggingFace `resolve` URLs. API keys are read from settings.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
from urllib.parse import unquote, urlparse, parse_qs
|
|
|
|
import httpx
|
|
|
|
from . import db
|
|
|
|
# Base URL of the CivitAI REST API (v1); every metadata request in this module
# is built on top of it, whichever CivitAI host the user's URL pointed at.
CIVITAI_API = "https://civitai.com/api/v1"

# CivitAI model `type` -> our model-type key. Lookups in this module fall back
# to "other" for any type not listed here.
CIVITAI_TYPE_MAP = {
    "Checkpoint": "checkpoint",
    # LoRA-family adapter formats all share the "lora" key.
    **dict.fromkeys(("LORA", "LoCon", "DoRA"), "lora"),
    "TextualInversion": "embedding",
    "Hypernetwork": "hypernetwork",
    "VAE": "vae",
    "Controlnet": "controlnet",
    "Upscaler": "upscaler",
    # Types with no dedicated local category.
    **dict.fromkeys(
        ("MotionModule", "Poses", "Wildcards", "Workflows", "Other"), "other"
    ),
}
|
|
|
|
|
|
@dataclass
class Resolved:
    """Outcome of resolving a user URL: what to download and how to request it."""

    # URL the file should be fetched from.
    download_url: str
    # Resolver that produced this result: "civitai", "huggingface" or "direct".
    source: str
    # Extra HTTP request headers (e.g. Authorization); a fresh empty dict by default.
    headers: dict[str, str] = field(default_factory=dict)
    # Suggested filename, when one could be determined.
    filename: Optional[str] = field(default=None)
    # Suggested model-type key, when the registry reports one.
    model_type: Optional[str] = field(default=None)
|
|
|
|
|
|
def _is_civitai_host(host: str) -> bool:
    """Return True if *host* looks like a CivitAI host.

    Deliberately loose: matches civitai.com, civitai.red and any other
    civitai.* mirror (any host name containing "civitai"). It only decides
    which resolver handles the URL.
    """
    return "civitai" in host


# Hugging Face domains; a host matches if it equals one of these or is a
# subdomain of one.
_HUGGINGFACE_DOMAINS = ("huggingface.co", "hf.co")


def _is_huggingface_host(host: str) -> bool:
    """Return True if *host* is a Hugging Face domain or one of its subdomains.

    Matched on whole DNS labels rather than by substring. The Hugging Face
    resolver attaches the user's HF token, so look-alike hosts such as
    ``myhf.com`` or ``huggingface.co.example.net`` must not qualify.
    """
    return any(host == d or host.endswith("." + d) for d in _HUGGINGFACE_DOMAINS)


def detect_source(url: str) -> str:
    """Classify *url* as "civitai", "huggingface" or "direct" by its host name."""
    # hostname is already lower-cased by urlparse; strip a trailing root dot
    # so "huggingface.co." is treated like "huggingface.co".
    host = (urlparse(url).hostname or "").lower().rstrip(".")
    if _is_civitai_host(host):
        return "civitai"
    if _is_huggingface_host(host):
        return "huggingface"
    return "direct"
|
|
|
|
|
|
def _filename_from_url(url: str) -> Optional[str]:
    """Return the percent-decoded last path segment of *url*, or None.

    The path is decoded *before* splitting, and backslashes are treated as
    separators. This stops an encoded separator (``%2F`` / ``%5C``) from
    smuggling directory components such as ``../`` into the returned name.
    Returns None when there is no usable final segment ("", "." or "..").
    """
    decoded = unquote(urlparse(url).path).replace("\\", "/")
    name = decoded.rsplit("/", 1)[-1]
    if name in ("", ".", ".."):
        return None
    return name
|
|
|
|
|
|
def civitai_headers() -> dict[str, str]:
    """Return request headers for CivitAI, with a Bearer token when configured.

    Reads the ``civitai_api_key`` setting. A new dict is returned on every
    call; it is empty when the setting is unset or empty.
    """
    api_key = db.get_setting("civitai_api_key")
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
|
|
|
|
|
async def resolve(url: str) -> Resolved:
    """Resolve a user-supplied URL into a concrete download.

    Dispatches on detect_source(). CivitAI URLs require an awaited API lookup;
    Hugging Face and plain URLs are resolved synchronously.
    """
    kind = detect_source(url)
    if kind == "civitai":
        return await _resolve_civitai(url)
    resolver = _resolve_huggingface if kind == "huggingface" else _resolve_direct
    return resolver(url)
|
|
|
|
|
|
def _resolve_direct(url: str) -> Resolved:
    """Pass a plain URL through unchanged, with no extra headers.

    The suggested filename is taken from the URL path.
    """
    suggested_name = _filename_from_url(url)
    return Resolved(download_url=url, source="direct", filename=suggested_name)
|
|
|
|
|
|
def _resolve_huggingface(url: str) -> Resolved:
    """Resolve a Hugging Face URL as-is, authenticating with the saved token.

    When the ``huggingface_token`` setting is non-empty it is sent as a Bearer
    token. The suggested filename comes from the URL path.
    """
    token = db.get_setting("huggingface_token")
    auth = {"Authorization": f"Bearer {token}"} if token else {}
    return Resolved(
        download_url=url,
        source="huggingface",
        headers=auth,
        filename=_filename_from_url(url),
    )
|
|
|
|
|
|
# Direct API download links: .../api/download/models/{versionId}
_CIVITAI_VERSION_RE = re.compile(r"/api/download/models/(\d+)")
# Any path naming a model id: .../models/{modelId}
_CIVITAI_MODEL_RE = re.compile(r"/models/(\d+)")


async def _civitai_get(path: str) -> dict:
    """GET ``CIVITAI_API + path`` with the CivitAI auth headers; return the parsed JSON.

    Raises httpx.HTTPStatusError (via raise_for_status) on an unsuccessful status.
    """
    async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as client:
        resp = await client.get(f"{CIVITAI_API}{path}")
        resp.raise_for_status()
        return resp.json()


def _is_numeric_id(value: str) -> bool:
    """Return True if *value* is a non-empty string of ASCII digits."""
    return value.isascii() and value.isdigit()


async def _civitai_version_id(url: str) -> str:
    """Extract or look up the model-version id from any supported CivitAI URL.

    Forms are tried in this order:
      1. ``/api/download/models/{versionId}`` download links;
      2. a ``modelVersionId`` query parameter;
      3. a ``/models/{modelId}`` path, looked up via the API and resolved to
         the first entry of the model's ``modelVersions``.

    Raises:
        ValueError: the URL matches none of the forms, the ``modelVersionId``
            value is not a plain decimal id, or the model has no versions.
        httpx.HTTPStatusError: the model lookup in form 3 failed.
    """
    parsed = urlparse(url)

    # api/download/models/{versionId}
    m = _CIVITAI_VERSION_RE.search(parsed.path)
    if m:
        return m.group(1)

    # ...?modelVersionId=...
    qs = parse_qs(parsed.query)
    if "modelVersionId" in qs:
        version_id = qs["modelVersionId"][0].strip()
        # The id is interpolated into an API path, so accept only a plain
        # number rather than arbitrary user-controlled text.
        if not _is_numeric_id(version_id):
            raise ValueError(f"Invalid CivitAI modelVersionId: {version_id!r}")
        return version_id

    # model page: /models/{id} -> first version
    m = _CIVITAI_MODEL_RE.search(parsed.path)
    if m:
        data = await _civitai_get(f"/models/{m.group(1)}")
        versions = data.get("modelVersions") or []
        if not versions:
            raise ValueError("CivitAI model has no downloadable versions")
        return str(versions[0]["id"])

    raise ValueError("Unrecognized CivitAI URL")
|
|
|
|
|
|
async def _resolve_civitai(url: str) -> Resolved:
    """Resolve a CivitAI URL to a concrete file through the model-versions API.

    The version record is always fetched, even for direct download links,
    because it carries the real filename and the model type. This works the
    same for civitai.com and the civitai.red mirror.

    Raises:
        ValueError: the URL is not recognized, or the version lists no files.
    """
    auth_headers = civitai_headers()
    version_id = await _civitai_version_id(url)
    version = await _civitai_get(f"/model-versions/{version_id}")

    # Prefer the file CivitAI flags as primary, else the first listed file.
    candidates = version.get("files") or []
    fallback = candidates[0] if candidates else None
    chosen = next((f for f in candidates if f.get("primary")), fallback)
    if not chosen:
        raise ValueError("CivitAI version has no files")

    model_info = version.get("model") or {}
    model_type = CIVITAI_TYPE_MAP.get(model_info.get("type") or "Other", "other")

    download_url = (
        chosen.get("downloadUrl")
        or f"https://civitai.com/api/download/models/{version_id}"
    )
    return Resolved(
        download_url=download_url,
        source="civitai",
        headers=auth_headers,
        filename=chosen.get("name") or _filename_from_url(download_url),
        model_type=model_type,
    )
|
|
|
|
|
|
# ---- catalog search -------------------------------------------------------
|
|
|
|
def _pick_thumbnail(version: dict) -> Optional[dict]:
    """Choose a preview for a model version, as a ``{"url", "type"}`` dict.

    The first entry of type "image" wins. Otherwise the first media entry is
    used, keeping its own type ("image" if the key is absent). Returns None
    when the version has no media.
    """
    media = version.get("images") or []
    still = next((m for m in media if m.get("type") == "image"), None)
    if still is not None:
        return {"url": still.get("url"), "type": "image"}
    if not media:
        return None
    first = media[0]
    return {"url": first.get("url"), "type": first.get("type", "image")}
|
|
|
|
|
|
# `availability` values whose files can be downloaded without purchased access.
# NOTE(review): based on CivitAI's Availability states (Public / Unsearchable /
# Private / EarlyAccess). "Unsearchable" only hides a version from search and
# does not gate downloads -- confirm against the API.
_OPEN_AVAILABILITY = frozenset({"Public", "Unsearchable"})


def _is_early_access(version: dict) -> bool:
    """Return True when downloading *version* needs purchased/early access.

    A missing or empty ``availability`` counts as "Public". Any value outside
    the openly downloadable set (e.g. "EarlyAccess") is flagged, so unknown
    states err on the side of warning the user.
    """
    return (version.get("availability") or "Public") not in _OPEN_AVAILABILITY
|
|
|
|
|
|
def _to_card(item: dict) -> dict:
    """Normalize one CivitAI ``/models`` search item into a flat card dict.

    The first entry of ``modelVersions`` is treated as the primary version. It
    supplies the thumbnail, ``primary_version_id`` and the card-level
    ``early_access`` flag. Every version is also listed under ``versions`` with
    its own ``early_access`` flag.
    """
    all_versions = item.get("modelVersions") or []
    primary = all_versions[0] if all_versions else {}
    preview = _pick_thumbnail(primary) if primary else None
    stats = item.get("stats") or {}
    civ_type = item.get("type")

    version_entries = [
        {
            "id": ver.get("id"),
            "name": ver.get("name"),
            "early_access": _is_early_access(ver),
        }
        for ver in all_versions
    ]

    card = {
        "id": item.get("id"),
        "name": item.get("name"),
        "type": civ_type,
        "type_key": CIVITAI_TYPE_MAP.get(civ_type or "", "other"),
        "nsfw": item.get("nsfw", False),
        "creator": (item.get("creator") or {}).get("username"),
        "downloads": stats.get("downloadCount"),
        "thumbnail": preview.get("url") if preview else None,
        "thumbnail_type": preview.get("type") if preview else None,
        "primary_version_id": primary.get("id"),
        "early_access": _is_early_access(primary),
        "versions": version_entries,
    }
    return card
|
|
|
|
|
|
async def civitai_search(query: Optional[str] = None,
                         types: Optional[list[str]] = None,
                         sort: Optional[str] = None,
                         period: Optional[str] = None,
                         nsfw: Optional[bool] = None,
                         page: Optional[int] = None,
                         cursor: Optional[str] = None,
                         limit: int = 24) -> dict:
    """Search the CivitAI model catalog, returning normalized cards + paging.

    CivitAI pages keyword searches (``query`` set) with cursors only, and page
    numbers otherwise. This function therefore never sends ``page`` together
    with ``query``, and it returns both paging styles from the response
    metadata.

    Args:
        query: free-text search; switches CivitAI to cursor paging.
        types: CivitAI model types to filter on (sent as repeated params).
        sort, period: passed through to CivitAI when given.
        nsfw: when not None, sent as "true"/"false".
        page: page number; ignored when ``cursor`` or ``query`` is given.
        cursor: opaque cursor from a previous ``nextCursor``.
        limit: page size, clamped to 1..100.

    Returns:
        ``{"items", "nextPage", "nextCursor", "currentPage", "totalPages"}``.

    Raises:
        httpx.HTTPStatusError: CivitAI answered with an unsuccessful status.
    """
    params: dict = {"limit": max(1, min(limit, 100))}
    if query:
        params["query"] = query
    if types:
        params["types"] = types  # httpx serializes lists as repeated params
    if sort:
        params["sort"] = sort
    if period:
        params["period"] = period
    if nsfw is not None:
        params["nsfw"] = "true" if nsfw else "false"
    if cursor:
        params["cursor"] = cursor
    elif page and not query:
        # Keyword searches are cursor-paged; CivitAI refuses a `page` param
        # there, so only page-number paging is used for non-query listings.
        params["page"] = page

    async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as client:
        resp = await client.get(f"{CIVITAI_API}/models", params=params)
        resp.raise_for_status()
        data = resp.json()

    items = [_to_card(it) for it in (data.get("items") or [])]
    meta = data.get("metadata") or {}
    return {
        "items": items,
        "nextPage": meta.get("nextPage"),
        "nextCursor": meta.get("nextCursor"),
        "currentPage": meta.get("currentPage"),
        "totalPages": meta.get("totalPages"),
    }
|