Files
TBNilles 399acabd58 feat(model-manager): "Free GPU memory" button to unload ComfyUI models
ComfyUI caches the last model when RAM is plentiful (unified memory), so
memory doesn't drop after switching models even though models are being
swapped, not accumulated. Add a sidebar "Free GPU memory" button that
proxies ComfyUI's POST /free (unload_models + free_memory) via a new
/api/comfyui/free endpoint (COMFYUI_URL env). Verified it releases ~7GB.
README documents this plus the --disable-smart-memory auto-unload option.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-07 17:14:37 -04:00

345 lines
11 KiB
Python

"""SparkyUI Model Manager - FastAPI app, API routes, and static UI."""
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from typing import Optional
import httpx
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from . import db, downloader, registries
from .config import (
COMFYUI_PORT,
COMFYUI_URL,
COMFYUIMINI_PORT,
IMAGE_EXTS,
KEY_BY_FOLDER,
MODEL_TYPES,
MODELS_DIR,
OUTPUT_DIR,
TYPE_BY_KEY,
ensure_dirs,
safe_output_path,
)
# Folder containing the bundled front-end pages (index.html, start.html, ...)
# that the page routes and the static mount at the bottom of this module serve.
STATIC_DIR = Path(__file__).parent.joinpath("static")
# The ASGI application all routes in this module are registered on.
app = FastAPI(title="SparkyUI Model Manager")
@app.on_event("startup")
def _startup() -> None:
    """Prepare storage and reconcile download state when the app boots.

    Calls ``ensure_dirs()`` and ``db.init_db()``, then marks every download
    still recorded as queued or downloading as failed: after a restart no
    task is running for it any more.
    """
    ensure_dirs()
    db.init_db()
    stale_states = ("queued", "downloading")
    for record in db.list_downloads():
        if record["status"] not in stale_states:
            continue
        db.update_download(record["id"], status="failed",
                           error="Interrupted by restart")
# ---- request models -------------------------------------------------------
class SettingsIn(BaseModel):
    # Request body for POST /api/settings (see save_settings): a field left
    # as None keeps the stored value, an empty string clears it.
    civitai_api_key: Optional[str] = None
    huggingface_token: Optional[str] = None
class DownloadIn(BaseModel):
    # Request body for POST /api/downloads (see start_download).
    url: str  # source URL; stripped, then resolved via registries.resolve
    # Explicit model type key; when omitted, start_download falls back to
    # the type suggested by the registry.
    model_type: Optional[str] = None
    # Target file name; when omitted, the registry-resolved name is used.
    filename: Optional[str] = None
# ---- model types & installed models ---------------------------------------
@app.get("/api/model-types")
def model_types() -> list[dict]:
    """Return the configured model-type catalogue (``MODEL_TYPES``) as-is."""
    catalogue = MODEL_TYPES
    return catalogue
@app.get("/api/models")
def list_models() -> list[dict]:
    """List installed model files across every known model folder.

    Folders that do not exist are skipped. Within a folder, entries are
    visited in sorted order; hidden names, in-progress ``.part`` files and
    anything that is not a file are ignored. Each result reports the file
    name, type key and label, folder, size in bytes and mtime.
    """
    installed: list[dict] = []
    for folder_name, key in KEY_BY_FOLDER.items():
        folder_path = MODELS_DIR / folder_name
        if not folder_path.is_dir():
            continue
        for item in sorted(folder_path.iterdir()):
            ignored = item.name.startswith(".") or item.name.endswith(".part")
            if ignored or not item.is_file():
                continue
            info = item.stat()
            installed.append({
                "name": item.name,
                "type": key,
                "type_label": TYPE_BY_KEY[key]["label"],
                "folder": folder_name,
                "size": info.st_size,
                "mtime": info.st_mtime,
            })
    return installed
class DeleteModelIn(BaseModel):
    # Request body for DELETE /api/models (see delete_model).
    folder: str  # must be one of the known model folders (KEY_BY_FOLDER)
    name: str  # file name; delete_model keeps only its final path component
@app.delete("/api/models")
def delete_model(body: DeleteModelIn) -> dict:
    """Delete one installed model file.

    ``body.folder`` must be a known model folder. Only the final path
    component of ``body.name`` is used, and the resolved target must stay
    inside that folder.

    Returns:
        ``{"deleted": <file name>}``.

    Raises:
        HTTPException 400: unknown folder, or a name resolving outside it.
        HTTPException 404: no file at that path, including one removed
            concurrently before it could be unlinked.
        HTTPException 403: the OS refused the delete (permissions).
    """
    if body.folder not in KEY_BY_FOLDER:
        raise HTTPException(400, "Unknown model folder")
    # Drop any directory components the client sent.
    name = os.path.basename(body.name)
    base = (MODELS_DIR / body.folder).resolve()
    target = (base / name).resolve()
    # resolve() collapses ".." and follows symlinks, so a prefix check on the
    # resolved path rejects anything that ends up outside `base`.
    if not str(target).startswith(str(base) + os.sep):
        raise HTTPException(400, "Invalid path")
    if not target.is_file():
        raise HTTPException(404, "File not found")
    try:
        target.unlink()
    except FileNotFoundError:
        # Removed by someone else between the check above and the unlink.
        raise HTTPException(404, "File not found") from None
    except PermissionError:
        # Mirror delete_photo: report a clear 403 instead of an opaque 500.
        raise HTTPException(
            403,
            "Permission denied (the model file or its folder is owned by "
            "another user).") from None
    return {"deleted": name}
# ---- settings -------------------------------------------------------------
@app.get("/api/settings")
def get_settings() -> dict:
    """Report which credentials are configured, never the secrets themselves."""
    civitai_configured = bool(db.get_setting("civitai_api_key"))
    hf_configured = bool(db.get_setting("huggingface_token"))
    return {
        "civitai_api_key_set": civitai_configured,
        "huggingface_token_set": hf_configured,
    }
@app.post("/api/settings")
def save_settings(body: SettingsIn) -> dict:
    """Store the supplied credentials and return the resulting status flags.

    A field left as None is not touched; an empty string clears it.
    """
    updates = (
        ("civitai_api_key", body.civitai_api_key),
        ("huggingface_token", body.huggingface_token),
    )
    for setting_name, new_value in updates:
        if new_value is None:
            continue
        db.set_setting(setting_name, new_value)
    return get_settings()
# ---- downloads ------------------------------------------------------------
@app.get("/api/downloads")
def get_downloads() -> list[dict]:
    """Return every download record held by the database."""
    records = db.list_downloads()
    return records
# Strong references to scheduled download tasks. asyncio keeps only weak
# references to tasks, so a task nobody references can be garbage-collected
# before it finishes; each task removes itself from this set when done.
_DOWNLOAD_TASKS: set[asyncio.Task] = set()


@app.post("/api/downloads")
async def start_download(body: DownloadIn) -> dict:
    """Resolve a model URL and start downloading it in the background.

    The URL is resolved through ``registries.resolve``. The model type is the
    caller's explicit choice if given, otherwise the registry's suggestion,
    and must be a known type key. A download row is created in the database
    and ``downloader.run_download`` is scheduled as a background task.

    Returns:
        The new download row as read back from the database, or
        ``{"id": download_id}`` if it cannot be read back.

    Raises:
        HTTPException 400: empty URL; resolution failure (with a dedicated
            hint when the source answered 401/403); missing or unknown
            model type.
    """
    url = body.url.strip()
    if not url:
        raise HTTPException(400, "URL is required")
    try:
        resolved = await registries.resolve(url)
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code in (401, 403):
            raise HTTPException(
                400,
                f"The source returned {code} for this model. It likely requires "
                "being logged in / early access on that site, or your API key/token "
                "doesn't have access to it. (Public models download fine.)") from exc
        raise HTTPException(400, f"Could not resolve URL (HTTP {code})") from exc
    except Exception as exc:  # noqa: BLE001 - report resolution failure to the user
        raise HTTPException(400, f"Could not resolve URL: {exc}") from exc

    # Pick the model type: explicit user choice wins, else registry suggestion.
    model_type = body.model_type or resolved.model_type
    if not model_type:
        raise HTTPException(
            400, "Please choose a model type for this download.")
    if model_type not in TYPE_BY_KEY:
        raise HTTPException(400, "Unknown model type")

    filename = body.filename or resolved.filename or ""
    download_id = db.create_download(
        url=resolved.download_url,
        source=resolved.source,
        model_type=model_type,
        filename=filename,
        dest_path="",
    )
    task = asyncio.create_task(downloader.run_download(
        download_id=download_id,
        url=resolved.download_url,
        headers=resolved.headers,
        model_type=model_type,
        filename=filename or None,
    ))
    # Keep the task alive until it completes (see _DOWNLOAD_TASKS).
    _DOWNLOAD_TASKS.add(task)
    task.add_done_callback(_DOWNLOAD_TASKS.discard)

    row = db.get_download(download_id)
    return row or {"id": download_id}
# ---- CivitAI catalog browse ----------------------------------------------
@app.get("/api/civitai/search")
async def civitai_search(
    query: Optional[str] = None,
    types: Optional[str] = None,  # comma-separated CivitAI types
    base_models: Optional[str] = None,  # comma-separated CivitAI base models
    sort: Optional[str] = "Most Downloaded",
    period: Optional[str] = "AllTime",
    nsfw: bool = False,
    page: Optional[int] = None,
    cursor: Optional[str] = None,
) -> dict:
    """Search the CivitAI model catalog (requires a stored CivitAI API key).

    ``types`` and ``base_models`` are comma-separated lists. Whitespace
    around each entry is ignored and empty entries are dropped. If no
    entries remain, None is passed on (no filter). All other arguments are
    forwarded unchanged to ``registries.civitai_search``.

    Raises:
        HTTPException 400: no CivitAI API key is configured.
        HTTPException 502: ``registries.civitai_search`` raised.
    """
    if not db.get_setting("civitai_api_key"):
        raise HTTPException(400, "Set your CivitAI API key in Settings first.")
    # Strip tokens so "LORA, Checkpoint" behaves like "LORA,Checkpoint";
    # internal spaces (e.g. "SD 1.5") are preserved.
    type_list = [
        t.strip() for t in (types or "").split(",") if t.strip()] or None
    base_model_list = [
        b.strip() for b in (base_models or "").split(",") if b.strip()] or None
    try:
        return await registries.civitai_search(
            query=query, types=type_list, base_models=base_model_list,
            sort=sort, period=period, nsfw=nsfw, page=page, cursor=cursor)
    except Exception as exc:  # noqa: BLE001 - surface API errors to the UI
        raise HTTPException(502, f"CivitAI search failed: {exc}") from exc
class CivitaiDownloadIn(BaseModel):
    # Request body for POST /api/civitai/download (see civitai_download).
    version_id: int  # CivitAI model-version id, used to build the download URL
@app.post("/api/civitai/download")
async def civitai_download(body: CivitaiDownloadIn) -> dict:
    """Queue a download of one CivitAI model version by its numeric id.

    Builds the CivitAI download URL for ``body.version_id`` and delegates to
    start_download without an explicit model type or filename.
    """
    request = DownloadIn(
        url=f"https://civitai.com/api/download/models/{body.version_id}")
    result = await start_download(request)
    return result
@app.delete("/api/downloads/{download_id}")
def delete_download(download_id: int) -> dict:
    """Cancel a running download, or otherwise remove its record.

    Raises HTTPException 404 when no download with this id exists.
    """
    if not db.get_download(download_id):
        raise HTTPException(404, "Download not found")
    cancel_requested = downloader.request_cancel(download_id)
    if cancel_requested:
        # Still running: the download task itself records the cancellation.
        return {"canceled": download_id}
    db.delete_download(download_id)
    return {"removed": download_id}
# ---- gallery (generated photos) -------------------------------------------
@app.get("/api/gallery")
def list_gallery(limit: int = 60, offset: int = 0) -> dict:
    """List generated images under OUTPUT_DIR, newest first, paginated.

    Args:
        limit: Page size, clamped to the range 1..500.
        offset: Number of newest images to skip. Negative values are
            treated as 0; used directly as a slice start, they would count
            from the end of the list and yield an arbitrary page.

    Returns:
        ``{"items": [...], "total": int, "offset": int, "returned": int}``.
        Each item carries the image's POSIX path relative to OUTPUT_DIR,
        its file name, its subfolder ("" at the top level), size in bytes
        and mtime. ``offset`` echoes the (clamped) offset actually used.
    """
    offset = max(0, offset)
    page_size = max(1, min(limit, 500))
    files: list[dict] = []
    if OUTPUT_DIR.is_dir():
        for path in OUTPUT_DIR.rglob("*"):
            # Cheap suffix test first so non-images never cost a stat call.
            if path.suffix.lower() not in IMAGE_EXTS or not path.is_file():
                continue
            try:
                stat = path.stat()
            except OSError:
                # Removed (e.g. by a concurrent gallery delete) after the
                # scan saw it; skip it instead of failing the whole listing.
                continue
            files.append({
                "path": path.relative_to(OUTPUT_DIR).as_posix(),
                "name": path.name,
                "subfolder": (
                    "" if path.parent == OUTPUT_DIR
                    else path.parent.relative_to(OUTPUT_DIR).as_posix()),
                "size": stat.st_size,
                "mtime": stat.st_mtime,
            })
    files.sort(key=lambda f: f["mtime"], reverse=True)
    page = files[offset:offset + page_size]
    return {"items": page, "total": len(files),
            "offset": offset, "returned": len(page)}
@app.delete("/api/gallery/all")
def delete_all_photos() -> dict:
    """Permanently delete every image under OUTPUT_DIR.

    Returns the number of images removed and the number whose removal
    raised OSError.
    """
    counts = {"deleted": 0, "failed": 0}
    if not OUTPUT_DIR.is_dir():
        return counts
    images = (
        candidate for candidate in OUTPUT_DIR.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in IMAGE_EXTS
    )
    for image in images:
        try:
            image.unlink()
        except OSError:
            counts["failed"] += 1
        else:
            counts["deleted"] += 1
    return counts
@app.get("/gallery/file")
def gallery_file(path: str):
    """Serve one generated image from OUTPUT_DIR.

    Raises HTTPException 400 when ``safe_output_path`` rejects ``path``, and
    404 when it is not an existing file with an image extension.
    """
    try:
        target = safe_output_path(path)
    except ValueError:
        raise HTTPException(400, "Invalid path")
    if target.is_file() and target.suffix.lower() in IMAGE_EXTS:
        return FileResponse(target)
    raise HTTPException(404, "Image not found")
class DeletePhotoIn(BaseModel):
    # Request body for DELETE /api/gallery (see delete_photo).
    # `path` is validated by safe_output_path; presumably the OUTPUT_DIR-
    # relative "path" value that /api/gallery reports -- verify with the UI.
    path: str
@app.delete("/api/gallery")
def delete_photo(body: DeletePhotoIn) -> dict:
    """Delete a single generated image under OUTPUT_DIR.

    Only files with an image extension (IMAGE_EXTS) are eligible, matching
    what the gallery lists, serves and bulk-deletes. Other files under
    OUTPUT_DIR are reported as not found rather than deleted.

    Returns:
        ``{"deleted": body.path}``.

    Raises:
        HTTPException 400: ``safe_output_path`` rejected the path.
        HTTPException 404: not an existing image file, including one removed
            concurrently before it could be unlinked.
        HTTPException 403: the OS refused the delete (permissions).
    """
    try:
        target = safe_output_path(body.path)
    except ValueError:
        raise HTTPException(400, "Invalid path") from None
    if not target.is_file() or target.suffix.lower() not in IMAGE_EXTS:
        raise HTTPException(404, "Image not found")
    try:
        target.unlink()
    except FileNotFoundError:
        # Removed by someone else between the check above and the unlink.
        raise HTTPException(404, "Image not found") from None
    except PermissionError:
        raise HTTPException(
            403,
            "Permission denied (file is in a folder owned by another user). "
            "Run ComfyUI as the same UID, or remove it from a host shell.") from None
    return {"deleted": body.path}
# ---- device routing / ui config -------------------------------------------
@app.get("/api/ui-config")
def ui_config() -> dict:
    """Expose the configured ComfyUI and ComfyUIMini ports to the front-end."""
    config = {"comfyui_port": COMFYUI_PORT}
    config["comfyuimini_port"] = COMFYUIMINI_PORT
    return config
@app.post("/api/comfyui/free")
async def comfyui_free() -> dict:
    """Ask ComfyUI to unload all models and free GPU/RAM (proxies POST /free).

    Sends ``{"unload_models": true, "free_memory": true}`` to the ``/free``
    endpoint of the ComfyUI instance at COMFYUI_URL.

    Returns:
        ``{"ok": True}`` once ComfyUI accepted the request.

    Raises:
        HTTPException 502: ComfyUI answered with an error status, or the
            request could not be made at all (connection error, timeout,
            malformed URL).
    """
    try:
        # Tolerate a configured base URL with a trailing slash, which would
        # otherwise produce ".../" + "/free".
        free_url = f"{str(COMFYUI_URL).rstrip('/')}/free"
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                free_url,
                json={"unload_models": True, "free_memory": True})
            resp.raise_for_status()
    except httpx.HTTPStatusError as exc:
        # ComfyUI is reachable but rejected the call; say so explicitly.
        raise HTTPException(
            502,
            f"ComfyUI returned HTTP {exc.response.status_code} for /free"
        ) from exc
    except Exception as exc:  # noqa: BLE001 - report to UI
        raise HTTPException(502, f"Could not reach ComfyUI: {exc}") from exc
    return {"ok": True}
@app.get("/start")
def start() -> FileResponse:
    """Serve the static start page (``start.html`` from STATIC_DIR)."""
    page = STATIC_DIR.joinpath("start.html")
    return FileResponse(page)
# ---- static UI ------------------------------------------------------------
@app.get("/")
def index() -> FileResponse:
    """Serve the main UI page (``index.html`` from STATIC_DIR)."""
    page = STATIC_DIR.joinpath("index.html")
    return FileResponse(page)
# Registered after every route: Starlette matches routes in registration
# order, so the API and page routes above take precedence over this
# catch-all mount that serves remaining paths from STATIC_DIR.
_static_files = StaticFiles(directory=str(STATIC_DIR))
app.mount("/", _static_files, name="static")