6871a6d460
- Gallery: DELETE /api/gallery/all removes every image under output/; "Delete all" button with in-app confirm and a deleted/failed count. - Downloads: surface a clear, actionable message when CivitAI/HuggingFace returns 401/403 (model requires login/early-access, or the key/token lacks access) instead of a bare error, both at resolve time and during the download stream. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
328 lines
10 KiB
Python
328 lines
10 KiB
Python
"""SparkyUI Model Manager - FastAPI app, API routes, and static UI."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from . import db, downloader, registries
|
|
from .config import (
|
|
COMFYUI_PORT,
|
|
COMFYUIMINI_PORT,
|
|
IMAGE_EXTS,
|
|
KEY_BY_FOLDER,
|
|
MODEL_TYPES,
|
|
MODELS_DIR,
|
|
OUTPUT_DIR,
|
|
TYPE_BY_KEY,
|
|
ensure_dirs,
|
|
safe_output_path,
|
|
)
|
|
|
|
# Directory bundled next to this module that holds the UI pages
# (index.html and start.html are served from it below).
STATIC_DIR = Path(__file__).parent.joinpath("static")

# The application instance; every route and the static mount in this module
# are attached to it.
app = FastAPI(title="SparkyUI Model Manager")
|
|
|
|
|
|
@app.on_event("startup")
def _startup() -> None:
    """Prepare directories and the database, then reconcile stale downloads.

    A download row still marked ``queued`` or ``downloading`` when the app
    starts has no running task behind it any more (a restart interrupted it),
    so it is flipped to ``failed`` with an explanatory error.
    """
    ensure_dirs()
    db.init_db()
    # NOTE(review): newer FastAPI versions deprecate on_event in favour of a
    # lifespan handler; consider migrating when touching app construction.
    unfinished = ("queued", "downloading")
    stale_ids = [row["id"] for row in db.list_downloads()
                 if row["status"] in unfinished]
    for stale_id in stale_ids:
        db.update_download(stale_id, status="failed",
                           error="Interrupted by restart")
|
|
|
|
|
|
# ---- request models -------------------------------------------------------
|
|
|
|
class SettingsIn(BaseModel):
    """Request body for ``POST /api/settings``.

    Both fields are optional: leaving one out (``None``) keeps the stored
    value, while an empty string clears it.
    """

    # CivitAI API key.
    civitai_api_key: Optional[str] = None
    # Hugging Face access token.
    huggingface_token: Optional[str] = None
|
|
|
|
|
|
class DownloadIn(BaseModel):
    """Request body for ``POST /api/downloads``.

    ``model_type`` and ``filename`` override whatever the registry resolves
    for ``url``; leave them unset to accept the registry's suggestions.
    """

    # Source URL; resolved through registries.resolve().
    url: str
    # Key of the target model type (must exist in TYPE_BY_KEY when used).
    model_type: Optional[str] = None
    # Optional explicit output file name.
    filename: Optional[str] = None
|
|
|
|
|
|
# ---- model types & installed models ---------------------------------------
|
|
|
|
@app.get("/api/model-types")
def model_types() -> list[dict]:
    """Return the model-type catalogue from config (``MODEL_TYPES``) unchanged."""
    catalogue = MODEL_TYPES
    return catalogue
|
|
|
|
|
|
@app.get("/api/models")
def list_models() -> list[dict]:
    """List installed model files across every known model folder.

    Folders are visited in ``KEY_BY_FOLDER`` order; entries inside a folder
    are sorted by name. Hidden files, in-progress ``.part`` files and
    anything that is not a regular file are skipped.

    Returns:
        One dict per file with ``name``, ``type`` (model-type key),
        ``type_label``, ``folder``, ``size`` (bytes) and ``mtime``
        (seconds since the epoch).
    """
    results: list[dict] = []
    for folder, type_key in KEY_BY_FOLDER.items():
        base = MODELS_DIR / folder
        if not base.is_dir():
            continue
        for entry in sorted(base.iterdir()):
            if entry.name.startswith(".") or entry.name.endswith(".part"):
                continue
            try:
                if not entry.is_file():
                    continue
                stat = entry.stat()
            except OSError:
                # The folder can change while we walk it (e.g. a concurrent
                # DELETE /api/models); an entry that vanishes between
                # iterdir() and stat() is skipped instead of failing the
                # whole listing with a 500.
                continue
            results.append({
                "name": entry.name,
                "type": type_key,
                "type_label": TYPE_BY_KEY[type_key]["label"],
                "folder": folder,
                "size": stat.st_size,
                "mtime": stat.st_mtime,
            })
    return results
|
|
|
|
|
|
class DeleteModelIn(BaseModel):
    """Request body for ``DELETE /api/models``."""

    # Model folder name; must be one of the keys of KEY_BY_FOLDER.
    folder: str
    # File name inside that folder (only its final path component is used).
    name: str
|
|
|
|
|
|
@app.delete("/api/models")
def delete_model(body: DeleteModelIn) -> dict:
    """Delete one installed model file.

    Args:
        body: the model folder (a KEY_BY_FOLDER key) and file name.

    Returns:
        ``{"deleted": <file name>}`` on success.

    Raises:
        HTTPException: 400 for an unknown folder or a name that escapes the
            folder, 404 if the file does not exist (or disappears before it
            can be removed), 403 if the OS refuses the deletion.
    """
    if body.folder not in KEY_BY_FOLDER:
        raise HTTPException(400, "Unknown model folder")
    # Keep only the last path component so "../" or nested paths cannot
    # select a file outside the chosen folder.
    name = os.path.basename(body.name)
    base = (MODELS_DIR / body.folder).resolve()
    target = (base / name).resolve()
    # resolve() also follows symlinks; the final location must still live
    # strictly inside the model folder.
    if not str(target).startswith(str(base) + os.sep):
        raise HTTPException(400, "Invalid path")
    if not target.is_file():
        raise HTTPException(404, "File not found")
    try:
        target.unlink()
    except FileNotFoundError as exc:
        # Removed between the is_file() check and the unlink.
        raise HTTPException(404, "File not found") from exc
    except PermissionError as exc:
        # Mirror delete_photo: report a clear 403 instead of a bare 500.
        raise HTTPException(
            403, "Permission denied while deleting this model file.") from exc
    return {"deleted": name}
|
|
|
|
|
|
# ---- settings -------------------------------------------------------------
|
|
|
|
@app.get("/api/settings")
def get_settings() -> dict:
    """Report which credentials are configured without revealing them.

    Each flag is True when the corresponding stored setting is non-empty.
    """
    flag_to_setting = {
        "civitai_api_key_set": "civitai_api_key",
        "huggingface_token_set": "huggingface_token",
    }
    return {flag: bool(db.get_setting(setting))
            for flag, setting in flag_to_setting.items()}
|
|
|
|
|
|
@app.post("/api/settings")
def save_settings(body: SettingsIn) -> dict:
    """Update stored credentials and return the configured/unset flags.

    ``None`` leaves a value untouched; an empty string clears it. Surrounding
    whitespace is stripped first. Otherwise a key pasted with a trailing
    space or newline would be stored verbatim, and a whitespace-only value
    would be reported as "set" by get_settings even though it is useless.

    Returns:
        The same payload as ``GET /api/settings``.
    """
    updates = (
        ("civitai_api_key", body.civitai_api_key),
        ("huggingface_token", body.huggingface_token),
    )
    for setting, value in updates:
        if value is not None:
            db.set_setting(setting, value.strip())
    return get_settings()
|
|
|
|
|
|
# ---- downloads ------------------------------------------------------------
|
|
|
|
@app.get("/api/downloads")
def get_downloads() -> list[dict]:
    """Return every tracked download row exactly as the database reports it."""
    rows = db.list_downloads()
    return rows
|
|
|
|
|
|
# Strong references to in-flight download tasks. The event loop keeps only
# weak references to tasks, so a task nobody holds on to may be garbage-
# collected before it finishes; each task drops itself from here when done.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


@app.post("/api/downloads")
async def start_download(body: DownloadIn) -> dict:
    """Resolve a model URL and start downloading it in the background.

    Args:
        body: source URL plus optional model-type and filename overrides.

    Returns:
        The freshly created download row (``db.get_download``), or
        ``{"id": <id>}`` if it cannot be read back.

    Raises:
        HTTPException: 400 for an empty URL, a resolution failure (with a
            dedicated explanation when the source answers 401/403), a
            missing model type, or an unknown model type.
    """
    url = body.url.strip()
    if not url:
        raise HTTPException(400, "URL is required")

    try:
        resolved = await registries.resolve(url)
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code in (401, 403):
            raise HTTPException(
                400,
                f"The source returned {code} for this model. It likely requires "
                "being logged in / early access on that site, or your API key/token "
                "doesn't have access to it. (Public models download fine.)") from exc
        raise HTTPException(400, f"Could not resolve URL (HTTP {code})") from exc
    except Exception as exc:  # noqa: BLE001 - report resolution failure to the user
        raise HTTPException(400, f"Could not resolve URL: {exc}") from exc

    # Pick the model type: explicit user choice wins, else registry suggestion.
    model_type = body.model_type or resolved.model_type
    if not model_type:
        raise HTTPException(
            400, "Please choose a model type for this download.")
    if model_type not in TYPE_BY_KEY:
        raise HTTPException(400, "Unknown model type")

    filename = body.filename or resolved.filename or ""

    download_id = db.create_download(
        url=resolved.download_url,
        source=resolved.source,
        model_type=model_type,
        filename=filename,
        dest_path="",
    )

    task = asyncio.create_task(downloader.run_download(
        download_id=download_id,
        url=resolved.download_url,
        headers=resolved.headers,
        model_type=model_type,
        filename=filename or None,
    ))
    # Keep the task alive until it completes (see _BACKGROUND_TASKS).
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)

    row = db.get_download(download_id)
    return row or {"id": download_id}
|
|
|
|
|
|
# ---- CivitAI catalog browse ----------------------------------------------
|
|
|
|
@app.get("/api/civitai/search")
async def civitai_search(
    query: Optional[str] = None,
    types: Optional[str] = None,  # comma-separated CivitAI types
    sort: Optional[str] = "Most Downloaded",
    period: Optional[str] = "AllTime",
    nsfw: bool = False,
    page: Optional[int] = None,
    cursor: Optional[str] = None,
) -> dict:
    """Proxy a CivitAI catalogue search through registries.civitai_search.

    ``types`` is a comma-separated list; surrounding whitespace on each entry
    is ignored (so ``"LORA, Checkpoint"`` works) and empty entries are dropped.
    An empty list is passed on as ``None``.

    Raises:
        HTTPException: 400 when no CivitAI API key is configured, 502 when
            the search call raises.
    """
    if not db.get_setting("civitai_api_key"):
        raise HTTPException(400, "Set your CivitAI API key in Settings first.")
    stripped = (t.strip() for t in (types or "").split(","))
    type_list = [t for t in stripped if t] or None
    try:
        return await registries.civitai_search(
            query=query, types=type_list, sort=sort, period=period,
            nsfw=nsfw, page=page, cursor=cursor)
    except Exception as exc:  # noqa: BLE001 - surface API errors to the UI
        raise HTTPException(502, f"CivitAI search failed: {exc}") from exc
|
|
|
|
|
|
class CivitaiDownloadIn(BaseModel):
    """Request body for ``POST /api/civitai/download``."""

    # CivitAI model-version id; used to build the
    # https://civitai.com/api/download/models/<id> URL.
    version_id: int
|
|
|
|
|
|
@app.post("/api/civitai/download")
async def civitai_download(body: CivitaiDownloadIn) -> dict:
    """Download one CivitAI model version through the generic URL flow.

    Builds the CivitAI download URL for ``body.version_id`` and delegates to
    start_download, so resolution, type selection and errors are identical.
    """
    version_url = f"https://civitai.com/api/download/models/{body.version_id}"
    request = DownloadIn(url=version_url)
    return await start_download(request)
|
|
|
|
|
|
@app.delete("/api/downloads/{download_id}")
def delete_download(download_id: int) -> dict:
    """Cancel a running download, or drop a finished one from the list.

    Returns ``{"canceled": id}`` when a cancel signal was sent to a running
    download, otherwise ``{"removed": id}`` after deleting the row.

    Raises:
        HTTPException: 404 if no download has this id.
    """
    if not db.get_download(download_id):
        raise HTTPException(404, "Download not found")
    if not downloader.request_cancel(download_id):
        db.delete_download(download_id)
        return {"removed": download_id}
    # Running: the download task itself marks the row canceled.
    return {"canceled": download_id}
|
|
|
|
|
|
# ---- gallery (generated photos) -------------------------------------------
|
|
|
|
@app.get("/api/gallery")
def list_gallery(limit: int = 60, offset: int = 0) -> dict:
    """List generated images under OUTPUT_DIR, newest first, paginated.

    Args:
        limit: page size, clamped to the range 1..500.
        offset: number of newest images to skip; negative values are
            treated as 0 (they previously produced negative slicing).

    Returns:
        ``items`` (each with ``path`` relative to OUTPUT_DIR in POSIX form,
        ``name``, ``subfolder`` -- ``""`` for the top level -- ``size`` and
        ``mtime``), ``total`` image count, the effective ``offset`` and
        ``returned`` (number of items in this page).
    """
    offset = max(0, offset)
    page_size = max(1, min(limit, 500))
    files: list[dict] = []
    if OUTPUT_DIR.is_dir():
        for path in OUTPUT_DIR.rglob("*"):
            if path.suffix.lower() not in IMAGE_EXTS:
                continue
            try:
                if not path.is_file():
                    continue
                stat = path.stat()
            except OSError:
                # Images can be removed while we walk (e.g. by the gallery
                # delete endpoints); skip them rather than return a 500.
                continue
            files.append({
                "path": path.relative_to(OUTPUT_DIR).as_posix(),
                "name": path.name,
                "subfolder": path.parent.relative_to(OUTPUT_DIR).as_posix()
                if path.parent != OUTPUT_DIR else "",
                "size": stat.st_size,
                "mtime": stat.st_mtime,
            })
    files.sort(key=lambda f: f["mtime"], reverse=True)
    page = files[offset:offset + page_size]
    return {"items": page, "total": len(files),
            "offset": offset, "returned": len(page)}
|
|
|
|
|
|
@app.delete("/api/gallery/all")
def delete_all_photos() -> dict:
    """Permanently delete every image under OUTPUT_DIR.

    Only files with an IMAGE_EXTS suffix are removed; other files and the
    (possibly now empty) subfolders are left in place. Returns how many
    images were deleted and how many raised OSError on unlink.
    """
    counts = {"deleted": 0, "failed": 0}
    if not OUTPUT_DIR.is_dir():
        return counts
    for candidate in OUTPUT_DIR.rglob("*"):
        if not (candidate.is_file()
                and candidate.suffix.lower() in IMAGE_EXTS):
            continue
        try:
            candidate.unlink()
        except OSError:
            counts["failed"] += 1
        else:
            counts["deleted"] += 1
    return counts
|
|
|
|
|
|
@app.get("/gallery/file")
def gallery_file(path: str):
    """Serve a single image from OUTPUT_DIR.

    ``path`` goes through safe_output_path; a ValueError from it becomes a
    400, and anything that is not an existing image file becomes a 404.
    """
    try:
        target = safe_output_path(path)
    except ValueError:
        raise HTTPException(400, "Invalid path")
    is_servable = target.is_file() and target.suffix.lower() in IMAGE_EXTS
    if not is_servable:
        raise HTTPException(404, "Image not found")
    return FileResponse(target)
|
|
|
|
|
|
class DeletePhotoIn(BaseModel):
    """Request body for ``DELETE /api/gallery``."""

    # Image path relative to OUTPUT_DIR (validated with safe_output_path).
    path: str
|
|
|
|
|
|
@app.delete("/api/gallery")
def delete_photo(body: DeletePhotoIn) -> dict:
    """Delete a single generated image.

    Only files with an IMAGE_EXTS suffix can be removed. This matches what
    /api/gallery lists, /gallery/file serves and /api/gallery/all deletes,
    so this endpoint can no longer remove arbitrary non-image files under
    OUTPUT_DIR.

    Returns:
        ``{"deleted": <path as given>}`` on success.

    Raises:
        HTTPException: 400 if safe_output_path rejects the path, 404 if it
            is not an existing image (or vanishes before removal), 403 on a
            permission error.
    """
    try:
        target = safe_output_path(body.path)
    except ValueError:
        raise HTTPException(400, "Invalid path")
    if not target.is_file() or target.suffix.lower() not in IMAGE_EXTS:
        raise HTTPException(404, "Image not found")
    try:
        target.unlink()
    except FileNotFoundError as exc:
        # Removed between the check above and the unlink (e.g. "Delete all").
        raise HTTPException(404, "Image not found") from exc
    except PermissionError:
        raise HTTPException(
            403,
            "Permission denied (file is in a folder owned by another user). "
            "Run ComfyUI as the same UID, or remove it from a host shell.")
    return {"deleted": body.path}
|
|
|
|
|
|
# ---- device routing / ui config -------------------------------------------
|
|
|
|
@app.get("/api/ui-config")
def ui_config() -> dict:
    """Return the configured ComfyUI and ComfyUIMini ports for the UI."""
    config = {"comfyui_port": COMFYUI_PORT}
    config["comfyuimini_port"] = COMFYUIMINI_PORT
    return config
|
|
|
|
|
|
@app.get("/start")
def start() -> FileResponse:
    """Serve the start page (``start.html`` from STATIC_DIR)."""
    page = STATIC_DIR.joinpath("start.html")
    return FileResponse(page)
|
|
|
|
|
|
# ---- static UI ------------------------------------------------------------
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve the main single-page UI (``index.html`` from STATIC_DIR)."""
    page = STATIC_DIR.joinpath("index.html")
    return FileResponse(page)
|
|
|
|
|
|
# Catch-all mount for the remaining files in STATIC_DIR. It is registered
# after every route above, and routes are matched in registration order, so
# the explicit API and page routes take precedence over static files.
_static_files = StaticFiles(directory=str(STATIC_DIR))
app.mount("/", _static_files, name="static")
|