"""Resolve a user-supplied URL into a concrete download (direct URL + auth + metadata) and provide CivitAI catalog search. Supports direct URLs, CivitAI (model/version pages, civitai.red mirror, or api download links), and HuggingFace `resolve` URLs. API keys are read from settings. """ from __future__ import annotations import re from dataclasses import dataclass, field from typing import Optional from urllib.parse import unquote, urlparse, parse_qs import httpx from . import db CIVITAI_API = "https://civitai.com/api/v1" # CivitAI model `type` -> our model-type key. CIVITAI_TYPE_MAP = { "Checkpoint": "checkpoint", "LORA": "lora", "LoCon": "lora", "DoRA": "lora", "TextualInversion": "embedding", "Hypernetwork": "hypernetwork", "VAE": "vae", "Controlnet": "controlnet", "Upscaler": "upscaler", "MotionModule": "other", "Poses": "other", "Wildcards": "other", "Workflows": "other", "Other": "other", } @dataclass class Resolved: download_url: str source: str headers: dict[str, str] = field(default_factory=dict) filename: Optional[str] = None model_type: Optional[str] = None # suggested type if the registry tells us def _is_civitai_host(host: str) -> bool: # Matches civitai.com, civitai.red, and any civitai.* mirror. return "civitai" in host def detect_source(url: str) -> str: host = (urlparse(url).hostname or "").lower() if _is_civitai_host(host): return "civitai" if "huggingface.co" in host or "hf.co" in host: return "huggingface" return "direct" def _filename_from_url(url: str) -> Optional[str]: path = urlparse(url).path name = unquote(path.rsplit("/", 1)[-1]) if path else "" return name or None def civitai_headers() -> dict[str, str]: headers: dict[str, str] = {} key = db.get_setting("civitai_api_key") if key: headers["Authorization"] = f"Bearer {key}" return headers async def resolve(url: str) -> Resolved: source = detect_source(url) if source == "civitai": return await _resolve_civitai(url) if source == "huggingface": return _resolve_huggingface(url) return _resolve_direct(url) def _resolve_direct(url: str) -> Resolved: return Resolved(download_url=url, source="direct", filename=_filename_from_url(url)) def _resolve_huggingface(url: str) -> Resolved: headers: dict[str, str] = {} token = db.get_setting("huggingface_token") if token: headers["Authorization"] = f"Bearer {token}" return Resolved(download_url=url, source="huggingface", headers=headers, filename=_filename_from_url(url)) _CIVITAI_VERSION_RE = re.compile(r"/api/download/models/(\d+)") _CIVITAI_MODEL_RE = re.compile(r"/models/(\d+)") async def _civitai_get(path: str) -> dict: async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as client: resp = await client.get(f"{CIVITAI_API}{path}") resp.raise_for_status() return resp.json() async def _civitai_version_id(url: str) -> str: """Extract or look up the model-version id from any CivitAI URL form.""" parsed = urlparse(url) # api/download/models/{versionId} m = _CIVITAI_VERSION_RE.search(parsed.path) if m: return m.group(1) # ...?modelVersionId=... qs = parse_qs(parsed.query) if "modelVersionId" in qs: return qs["modelVersionId"][0] # model page: /models/{id} -> first version m = _CIVITAI_MODEL_RE.search(parsed.path) if m: data = await _civitai_get(f"/models/{m.group(1)}") versions = data.get("modelVersions") or [] if not versions: raise ValueError("CivitAI model has no downloadable versions") return str(versions[0]["id"]) raise ValueError("Unrecognized CivitAI URL") async def _resolve_civitai(url: str) -> Resolved: """Resolve any CivitAI URL to a concrete file via the model-versions API. Always looks the version up so we get the real filename and model type (works for civitai.com and the civitai.red mirror alike). """ headers = civitai_headers() version_id = await _civitai_version_id(url) version = await _civitai_get(f"/model-versions/{version_id}") files = version.get("files") or [] chosen = next((f for f in files if f.get("primary")), files[0] if files else None) if not chosen: raise ValueError("CivitAI version has no files") civ_type = (version.get("model") or {}).get("type") or "Other" model_type = CIVITAI_TYPE_MAP.get(civ_type, "other") download_url = chosen.get("downloadUrl") or \ f"https://civitai.com/api/download/models/{version_id}" return Resolved( download_url=download_url, source="civitai", headers=headers, filename=chosen.get("name") or _filename_from_url(download_url), model_type=model_type, ) # ---- catalog search ------------------------------------------------------- def _pick_thumbnail(version: dict) -> Optional[dict]: """Return {url, type} of a representative preview for a model version.""" images = version.get("images") or [] for im in images: if im.get("type") == "image": return {"url": im.get("url"), "type": "image"} if images: return {"url": images[0].get("url"), "type": images[0].get("type", "image")} return None def _to_card(item: dict) -> dict: versions = item.get("modelVersions") or [] v0 = versions[0] if versions else {} thumb = _pick_thumbnail(v0) if v0 else None stats = item.get("stats") or {} return { "id": item.get("id"), "name": item.get("name"), "type": item.get("type"), "type_key": CIVITAI_TYPE_MAP.get(item.get("type") or "", "other"), "nsfw": item.get("nsfw", False), "creator": (item.get("creator") or {}).get("username"), "downloads": stats.get("downloadCount"), "thumbnail": thumb.get("url") if thumb else None, "thumbnail_type": thumb.get("type") if thumb else None, "primary_version_id": v0.get("id"), "versions": [ {"id": v.get("id"), "name": v.get("name")} for v in versions ], } async def civitai_search(query: Optional[str] = None, types: Optional[list[str]] = None, sort: Optional[str] = None, period: Optional[str] = None, nsfw: Optional[bool] = None, page: Optional[int] = None, cursor: Optional[str] = None, limit: int = 24) -> dict: """Search the CivitAI model catalog, returning normalized cards + paging. CivitAI uses cursor paging whenever a `query` is supplied (page is ignored), and page paging otherwise, so we support and return both. """ params: dict = {"limit": max(1, min(limit, 100))} if query: params["query"] = query if types: params["types"] = types # httpx serializes lists as repeated params if sort: params["sort"] = sort if period: params["period"] = period if nsfw is not None: params["nsfw"] = "true" if nsfw else "false" if cursor: params["cursor"] = cursor elif page: params["page"] = page async with httpx.AsyncClient(timeout=30, headers=civitai_headers()) as client: resp = await client.get(f"{CIVITAI_API}/models", params=params) resp.raise_for_status() data = resp.json() items = [_to_card(it) for it in (data.get("items") or [])] meta = data.get("metadata") or {} return { "items": items, "nextPage": meta.get("nextPage"), "nextCursor": meta.get("nextCursor"), "currentPage": meta.get("currentPage"), "totalPages": meta.get("totalPages"), }