"""SparkyUI Model Manager - FastAPI app, API routes, and static UI.""" from __future__ import annotations import asyncio import os from pathlib import Path from typing import Optional from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from . import db, downloader, registries from .config import ( KEY_BY_FOLDER, MODEL_TYPES, MODELS_DIR, TYPE_BY_KEY, ensure_dirs, ) STATIC_DIR = Path(__file__).parent / "static" app = FastAPI(title="SparkyUI Model Manager") @app.on_event("startup") def _startup() -> None: ensure_dirs() db.init_db() # Any download left mid-flight by a restart is no longer running. for d in db.list_downloads(): if d["status"] in ("queued", "downloading"): db.update_download(d["id"], status="failed", error="Interrupted by restart") # ---- request models ------------------------------------------------------- class SettingsIn(BaseModel): civitai_api_key: Optional[str] = None huggingface_token: Optional[str] = None class DownloadIn(BaseModel): url: str model_type: Optional[str] = None filename: Optional[str] = None # ---- model types & installed models --------------------------------------- @app.get("/api/model-types") def model_types() -> list[dict]: return MODEL_TYPES @app.get("/api/models") def list_models() -> list[dict]: results: list[dict] = [] for folder, type_key in KEY_BY_FOLDER.items(): base = MODELS_DIR / folder if not base.is_dir(): continue for entry in sorted(base.iterdir()): if entry.name.startswith(".") or entry.name.endswith(".part"): continue if not entry.is_file(): continue stat = entry.stat() results.append({ "name": entry.name, "type": type_key, "type_label": TYPE_BY_KEY[type_key]["label"], "folder": folder, "size": stat.st_size, "mtime": stat.st_mtime, }) return results class DeleteModelIn(BaseModel): folder: str name: str @app.delete("/api/models") def delete_model(body: DeleteModelIn) -> dict: if body.folder not in KEY_BY_FOLDER: raise HTTPException(400, "Unknown model folder") name = os.path.basename(body.name) base = (MODELS_DIR / body.folder).resolve() target = (base / name).resolve() if not str(target).startswith(str(base) + os.sep): raise HTTPException(400, "Invalid path") if not target.is_file(): raise HTTPException(404, "File not found") target.unlink() return {"deleted": name} # ---- settings ------------------------------------------------------------- @app.get("/api/settings") def get_settings() -> dict: # Never return the secrets themselves, only whether they are configured. return { "civitai_api_key_set": bool(db.get_setting("civitai_api_key")), "huggingface_token_set": bool(db.get_setting("huggingface_token")), } @app.post("/api/settings") def save_settings(body: SettingsIn) -> dict: # `None` leaves a value untouched; empty string clears it. if body.civitai_api_key is not None: db.set_setting("civitai_api_key", body.civitai_api_key) if body.huggingface_token is not None: db.set_setting("huggingface_token", body.huggingface_token) return get_settings() # ---- downloads ------------------------------------------------------------ @app.get("/api/downloads") def get_downloads() -> list[dict]: return db.list_downloads() @app.post("/api/downloads") async def start_download(body: DownloadIn) -> dict: url = body.url.strip() if not url: raise HTTPException(400, "URL is required") try: resolved = await registries.resolve(url) except Exception as exc: # noqa: BLE001 - report resolution failure to the user raise HTTPException(400, f"Could not resolve URL: {exc}") # Pick the model type: explicit user choice wins, else registry suggestion. model_type = body.model_type or resolved.model_type if not model_type: raise HTTPException( 400, "Please choose a model type for this download.") if model_type not in TYPE_BY_KEY: raise HTTPException(400, "Unknown model type") filename = body.filename or resolved.filename or "" download_id = db.create_download( url=resolved.download_url, source=resolved.source, model_type=model_type, filename=filename, dest_path="", ) asyncio.create_task(downloader.run_download( download_id=download_id, url=resolved.download_url, headers=resolved.headers, model_type=model_type, filename=filename or None, )) row = db.get_download(download_id) return row or {"id": download_id} @app.delete("/api/downloads/{download_id}") def delete_download(download_id: int) -> dict: row = db.get_download(download_id) if not row: raise HTTPException(404, "Download not found") # If it's running, signal cancel; the task will mark it canceled. if downloader.request_cancel(download_id): return {"canceled": download_id} db.delete_download(download_id) return {"removed": download_id} # ---- static UI ------------------------------------------------------------ @app.get("/") def index() -> FileResponse: return FileResponse(STATIC_DIR / "index.html") app.mount("/", StaticFiles(directory=str(STATIC_DIR)), name="static")