"""Resolve a user-supplied URL into a concrete download (direct URL + auth + metadata).

Supports direct URLs, CivitAI (model/version pages or api download links), and
HuggingFace `resolve` URLs. API keys are read from the settings table.
"""
from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Optional
from urllib.parse import unquote, urlparse, parse_qs

import httpx

from . import db

# CivitAI model `type` -> our model-type key.
CIVITAI_TYPE_MAP = {
    "Checkpoint": "checkpoint",
    "LORA": "lora",
    "LoCon": "lora",
    "DoRA": "lora",
    "TextualInversion": "embedding",
    "Hypernetwork": "hypernetwork",
    "VAE": "vae",
    "Controlnet": "controlnet",
    "Upscaler": "upscaler",
    "MotionModule": "other",
    "Poses": "other",
    "Wildcards": "other",
    "Other": "other",
}

# Registry domains. A host matches a registry only if it IS one of these
# domains or a subdomain of one. Plain substring checks are unsafe here: they
# would accept look-alikes such as "civitai.com.example.net" or "myhf.com", and
# the resolvers for these sources attach the user's stored API tokens.
_CIVITAI_DOMAINS = ("civitai.com",)
_HUGGINGFACE_DOMAINS = ("huggingface.co", "hf.co")


@dataclass
class Resolved:
    """Everything needed to start a download for a user-supplied URL."""

    download_url: str  # URL to actually fetch
    source: str  # "civitai", "huggingface" or "direct"
    headers: dict[str, str] = field(default_factory=dict)  # e.g. Authorization
    filename: Optional[str] = None  # best-guess filename, if one could be derived
    model_type: Optional[str] = None  # suggested type if the registry tells us


def _host_matches(host: str, domains: tuple[str, ...]) -> bool:
    """Return True if *host* equals one of *domains* or is a subdomain of one."""
    return any(host == d or host.endswith("." + d) for d in domains)


def detect_source(url: str) -> str:
    """Classify *url* as "civitai", "huggingface" or "direct" by its hostname."""
    # urlparse().hostname is already lower-cased; strip a trailing root dot
    # ("civitai.com.") so fully-qualified spellings still match.
    host = (urlparse(url).hostname or "").lower().rstrip(".")
    if _host_matches(host, _CIVITAI_DOMAINS):
        return "civitai"
    if _host_matches(host, _HUGGINGFACE_DOMAINS):
        return "huggingface"
    return "direct"


def _filename_from_url(url: str) -> Optional[str]:
    """Return the percent-decoded last path segment of *url*, or None if empty."""
    path = urlparse(url).path
    name = unquote(path.rsplit("/", 1)[-1]) if path else ""
    return name or None


async def resolve(url: str) -> Resolved:
    """Resolve *url* into a concrete download, dispatching on its source.

    Raises:
        httpx.HTTPError: a CivitAI API request failed.
        ValueError: a CivitAI URL could not be turned into a downloadable file.
    """
    source = detect_source(url)
    if source == "civitai":
        return await _resolve_civitai(url)
    if source == "huggingface":
        return _resolve_huggingface(url)
    return _resolve_direct(url)


def _resolve_direct(url: str) -> Resolved:
    """Use *url* as-is, with no auth headers."""
    return Resolved(download_url=url, source="direct", filename=_filename_from_url(url))


def _resolve_huggingface(url: str) -> Resolved:
    """Use a HuggingFace URL as-is, attaching the stored token if one is configured."""
    headers: dict[str, str] = {}
    token = db.get_setting("huggingface_token")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return Resolved(download_url=url, source="huggingface", headers=headers,
                    filename=_filename_from_url(url))


_CIVITAI_VERSION_RE = re.compile(r"/api/download/models/(\d+)")
_CIVITAI_MODEL_RE = re.compile(r"/models/(\d+)")
_CIVITAI_ID_RE = re.compile(r"[0-9]+")
_CIVITAI_API = "https://civitai.com/api/v1"


async def _resolve_civitai(url: str) -> Resolved:
    """Resolve a CivitAI URL to a concrete file download.

    Handles three shapes:
      1. ``/api/download/models/<versionId>`` links, which are returned unchanged.
      2. Model pages, optionally carrying ``?modelVersionId=<id>``. Without the
         query parameter, the first version listed by the API is used. The
         chosen version is then looked up to find its primary (else first) file
         and its model type.
      3. Anything else: returned unchanged as a best-effort direct link.

    Raises:
        httpx.HTTPError: an API request failed or returned an error status.
        ValueError: ``modelVersionId`` is not numeric, or the model/version has
            no versions/files.
    """
    headers: dict[str, str] = {}
    api_key = db.get_setting("civitai_api_key")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    parsed = urlparse(url)

    # Case 1: already a direct api/download URL -> use as-is.
    if _CIVITAI_VERSION_RE.search(parsed.path):
        return Resolved(download_url=url, source="civitai", headers=headers,
                        filename=_filename_from_url(url))

    # Case 2: a model page URL. Find the version id (explicit query or first version).
    version_id: Optional[str] = None
    qs = parse_qs(parsed.query)
    if "modelVersionId" in qs:
        version_id = qs["modelVersionId"][0]
        # The value is interpolated into an API URL path below; only accept a
        # plain numeric id so user input cannot alter the request path.
        if not _CIVITAI_ID_RE.fullmatch(version_id):
            raise ValueError(f"Invalid CivitAI modelVersionId: {version_id!r}")

    model_id: Optional[str] = None
    if version_id is None:
        m = _CIVITAI_MODEL_RE.search(parsed.path)
        if not m:
            # Can't understand it; fall back to treating it as a direct link.
            return Resolved(download_url=url, source="civitai", headers=headers,
                            filename=_filename_from_url(url))
        model_id = m.group(1)

    # One client for both lookups: same host, same auth, same timeout.
    async with httpx.AsyncClient(timeout=30, headers=headers) as client:
        if version_id is None:
            resp = await client.get(f"{_CIVITAI_API}/models/{model_id}")
            resp.raise_for_status()
            versions = resp.json().get("modelVersions") or []
            if not versions:
                raise ValueError("CivitAI model has no downloadable versions")
            version_id = str(versions[0]["id"])

        # Resolve the version to a concrete file + type.
        resp = await client.get(f"{_CIVITAI_API}/model-versions/{version_id}")
        resp.raise_for_status()
        version = resp.json()

    files = version.get("files") or []
    # Prefer the primary file, else the first.
    chosen = next((f for f in files if f.get("primary")), files[0] if files else None)
    if not chosen:
        raise ValueError("CivitAI version has no files")

    civ_type = (version.get("model") or {}).get("type") or "Other"
    model_type = CIVITAI_TYPE_MAP.get(civ_type, "other")

    return Resolved(
        download_url=chosen["downloadUrl"],
        source="civitai",
        headers=headers,
        filename=chosen.get("name") or _filename_from_url(chosen["downloadUrl"]),
        model_type=model_type,
    )