e8115e7aa6
- New "Browse CivitAI" view: thumbnail grid with search, type/sort/period filters and NSFW toggle; click a model card to download it (per-card version picker for multi-version models). Cursor + page based "Load more". - Backend: /api/civitai/search and /api/civitai/download endpoints; new civitai_search() catalog helper. - Fix 401 on paste: recognize the civitai.red mirror (and any civitai.* host), normalize API calls to civitai.com, and always resolve the model-version so type + filename are auto-detected for every CivitAI URL. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
225 lines
6.7 KiB
Python
225 lines
6.7 KiB
Python
"""SparkyUI Model Manager - FastAPI app, API routes, and static UI."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from . import db, downloader, registries
|
|
from .config import (
|
|
KEY_BY_FOLDER,
|
|
MODEL_TYPES,
|
|
MODELS_DIR,
|
|
TYPE_BY_KEY,
|
|
ensure_dirs,
|
|
)
|
|
|
|
# Directory holding the bundled front-end assets: index.html (served by
# ``index``) and everything exposed through the "/" static mount.
STATIC_DIR = Path(__file__).parent.joinpath("static")

# The ASGI application every route below is registered on.
app = FastAPI(title="SparkyUI Model Manager")
|
|
|
|
|
|
@app.on_event("startup")
def _startup() -> None:
    """Prepare storage and reconcile stale download rows on application start.

    Runs ``ensure_dirs()`` and ``db.init_db()``, then flags every download
    still recorded as queued or downloading as failed, since no task from a
    previous process can still be working on it.
    """
    ensure_dirs()
    db.init_db()

    in_flight_states = ("queued", "downloading")
    for record in db.list_downloads():
        if record["status"] not in in_flight_states:
            continue
        # Any download left mid-flight by a restart is no longer running.
        db.update_download(
            record["id"], status="failed", error="Interrupted by restart")
|
|
|
|
|
|
# ---- request models -------------------------------------------------------
|
|
|
|
class SettingsIn(BaseModel):
    """Request body for ``POST /api/settings``.

    For each field, ``None`` leaves the stored value untouched and an empty
    string clears it.
    """

    # CivitAI API key, or None to keep the current value.
    civitai_api_key: Optional[str] = None
    # Hugging Face access token, or None to keep the current value.
    huggingface_token: Optional[str] = None
|
|
|
|
|
|
class DownloadIn(BaseModel):
    """Request body for ``POST /api/downloads`` (also built by ``civitai_download``).

    Only ``url`` is required; the optional fields override what the URL
    resolver suggests.
    """

    # Model page or direct download URL to resolve.
    url: str
    # Key into the model-type catalogue; falls back to the resolver's guess.
    model_type: Optional[str] = None
    # Target file name; falls back to the resolver's guess.
    filename: Optional[str] = None
|
|
|
|
|
|
# ---- model types & installed models ---------------------------------------
|
|
|
|
@app.get("/api/model-types")
def model_types() -> list[dict]:
    """Return the model-type catalogue from the config module, as-is."""
    catalogue = MODEL_TYPES
    return catalogue
|
|
|
|
|
|
@app.get("/api/models")
def list_models() -> list[dict]:
    """List the regular files installed in every known model folder.

    Each folder from ``KEY_BY_FOLDER`` is looked up under ``MODELS_DIR``;
    folders that do not exist are skipped. Names starting with ``.`` and
    names ending in ``.part`` (presumably unfinished downloads) are ignored,
    as are non-files. Entries come out grouped by folder, sorted by path
    within each folder.

    Returns:
        One dict per file with ``name``, ``type``, ``type_label``,
        ``folder``, ``size`` (``st_size``) and ``mtime`` (``st_mtime``).
    """
    found: list[dict] = []
    for folder, type_key in KEY_BY_FOLDER.items():
        folder_path = MODELS_DIR / folder
        if not folder_path.is_dir():
            continue
        for item in sorted(folder_path.iterdir()):
            skip_by_name = item.name.startswith(".") or item.name.endswith(".part")
            if skip_by_name or not item.is_file():
                continue
            info = item.stat()
            found.append({
                "name": item.name,
                "type": type_key,
                "type_label": TYPE_BY_KEY[type_key]["label"],
                "folder": folder,
                "size": info.st_size,
                "mtime": info.st_mtime,
            })
    return found
|
|
|
|
|
|
class DeleteModelIn(BaseModel):
    """Request body for ``DELETE /api/models`` identifying one installed file."""

    # Model folder name; must be one of the known folders in KEY_BY_FOLDER.
    folder: str
    # File name inside ``folder``; the handler keeps only its last path component.
    name: str
|
|
|
|
|
|
@app.delete("/api/models")
def delete_model(body: DeleteModelIn) -> dict:
    """Delete a single installed model file.

    Only the final path component of ``body.name`` is used, and the resolved
    target must lie strictly inside the resolved model folder. Because
    ``resolve()`` follows symlinks, a link pointing elsewhere is rejected too.

    Returns:
        ``{"deleted": <file name>}`` on success.

    Raises:
        HTTPException(400): unknown folder, or the name resolves outside it.
        HTTPException(404): no regular file exists at the target.
    """
    if body.folder not in KEY_BY_FOLDER:
        raise HTTPException(400, "Unknown model folder")

    safe_name = os.path.basename(body.name)
    folder_root = (MODELS_DIR / body.folder).resolve()
    candidate = (folder_root / safe_name).resolve()

    # Requiring the separator suffix also rejects the folder itself and any
    # sibling whose name merely shares the folder's prefix.
    inside_folder = str(candidate).startswith(str(folder_root) + os.sep)
    if not inside_folder:
        raise HTTPException(400, "Invalid path")
    if not candidate.is_file():
        raise HTTPException(404, "File not found")

    candidate.unlink()
    return {"deleted": safe_name}
|
|
|
|
|
|
# ---- settings -------------------------------------------------------------
|
|
|
|
@app.get("/api/settings")
def get_settings() -> dict:
    """Report which secrets are configured, never the secret values themselves.

    Returns:
        A dict mapping ``<setting>_set`` to whether that setting is non-empty.
    """
    flag_for_setting = (
        ("civitai_api_key", "civitai_api_key_set"),
        ("huggingface_token", "huggingface_token_set"),
    )
    flags: dict = {}
    for setting_key, flag_key in flag_for_setting:
        flags[flag_key] = bool(db.get_setting(setting_key))
    return flags
|
|
|
|
|
|
@app.post("/api/settings")
def save_settings(body: SettingsIn) -> dict:
    """Store any secrets supplied in the request and report the new state.

    For each field, ``None`` leaves the stored value untouched and an empty
    string clears it. Values are stripped first: API keys and tokens never
    contain surrounding whitespace, and a stray space or newline from a paste
    would otherwise be stored and sent verbatim. A whitespace-only value
    therefore clears the setting.

    Returns:
        The same payload as ``get_settings()`` (which secrets are set).
    """
    submitted = (
        ("civitai_api_key", body.civitai_api_key),
        ("huggingface_token", body.huggingface_token),
    )
    for setting_key, value in submitted:
        if value is None:
            continue
        db.set_setting(setting_key, value.strip())
    return get_settings()
|
|
|
|
|
|
# ---- downloads ------------------------------------------------------------
|
|
|
|
@app.get("/api/downloads")
def get_downloads() -> list[dict]:
    """Return every recorded download row exactly as ``db.list_downloads()`` yields it."""
    rows = db.list_downloads()
    return rows
|
|
|
|
|
|
# Strong references to fire-and-forget download tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task can be garbage
# collected before it finishes; each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


@app.post("/api/downloads")
async def start_download(body: DownloadIn) -> dict:
    """Resolve a model URL, record a download row and start it in the background.

    The URL goes through ``registries.resolve``. The model type and file name
    come from the request when given, otherwise from the resolver's result.
    A user-supplied file name is reduced to its last path component, matching
    the sanitising done by ``delete_model``.

    Returns:
        The newly created download row, or ``{"id": <id>}`` if it cannot be
        read back.

    Raises:
        HTTPException(400): empty URL, unresolvable URL, or a missing/unknown
            model type.
    """
    url = body.url.strip()
    if not url:
        raise HTTPException(400, "URL is required")

    try:
        resolved = await registries.resolve(url)
    except Exception as exc:  # noqa: BLE001 - report resolution failure to the user
        raise HTTPException(400, f"Could not resolve URL: {exc}") from exc

    # Pick the model type: explicit user choice wins, else registry suggestion.
    model_type = body.model_type or resolved.model_type
    if not model_type:
        raise HTTPException(
            400, "Please choose a model type for this download.")
    if model_type not in TYPE_BY_KEY:
        raise HTTPException(400, "Unknown model type")

    # Keep a client-supplied name from pointing outside the model folder;
    # if nothing usable is left, fall back to the resolver's suggestion.
    user_filename = os.path.basename(body.filename) if body.filename else ""
    filename = user_filename or resolved.filename or ""

    download_id = db.create_download(
        url=resolved.download_url,
        source=resolved.source,
        model_type=model_type,
        filename=filename,
        dest_path="",
    )

    task = asyncio.create_task(downloader.run_download(
        download_id=download_id,
        url=resolved.download_url,
        headers=resolved.headers,
        model_type=model_type,
        filename=filename or None,
    ))
    # Hold the task until it completes so it cannot vanish mid-download.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    row = db.get_download(download_id)
    return row or {"id": download_id}
|
|
|
|
|
|
# ---- CivitAI catalog browse ----------------------------------------------
|
|
|
|
@app.get("/api/civitai/search")
async def civitai_search(
    query: Optional[str] = None,
    types: Optional[str] = None,  # comma-separated CivitAI types
    sort: Optional[str] = "Most Downloaded",
    period: Optional[str] = "AllTime",
    nsfw: bool = False,
    page: Optional[int] = None,
    cursor: Optional[str] = None,
) -> dict:
    """Proxy a CivitAI catalog search to ``registries.civitai_search``.

    ``types`` is a comma-separated list. Each entry is trimmed and blank
    entries are dropped, so "LORA, Checkpoint" and "LORA,Checkpoint" are
    equivalent. No entries means no type filter (``None``).

    Raises:
        HTTPException(400): no CivitAI API key has been saved.
        HTTPException(502): the search helper raised; its message is included.
    """
    if not db.get_setting("civitai_api_key"):
        raise HTTPException(400, "Set your CivitAI API key in Settings first.")
    type_list = [t.strip() for t in (types or "").split(",") if t.strip()] or None
    try:
        return await registries.civitai_search(
            query=query, types=type_list, sort=sort, period=period,
            nsfw=nsfw, page=page, cursor=cursor)
    except Exception as exc:  # noqa: BLE001 - surface API errors to the UI
        raise HTTPException(502, f"CivitAI search failed: {exc}") from exc
|
|
|
|
|
|
class CivitaiDownloadIn(BaseModel):
    """Request body for ``POST /api/civitai/download``."""

    # CivitAI model-version ID; interpolated into the civitai.com download URL.
    version_id: int
|
|
|
|
|
|
@app.post("/api/civitai/download")
async def civitai_download(body: CivitaiDownloadIn) -> dict:
    """Queue a download of one CivitAI model version picked in the browser.

    Builds the civitai.com download URL for ``body.version_id`` and hands it
    to ``start_download``, leaving model type and file name to the resolver.
    """
    download_url = f"https://civitai.com/api/download/models/{body.version_id}"
    request = DownloadIn(url=download_url)
    return await start_download(request)
|
|
|
|
|
|
@app.delete("/api/downloads/{download_id}")
def delete_download(download_id: int) -> dict:
    """Cancel a running download, or remove a finished one from the history.

    Returns:
        ``{"canceled": id}`` if a cancel request was accepted for a running
        download; otherwise ``{"removed": id}`` after deleting the row.

    Raises:
        HTTPException(404): no download row with this id.
    """
    existing = db.get_download(download_id)
    if not existing:
        raise HTTPException(404, "Download not found")

    # A running download is only signalled here; its task is expected to
    # mark the row canceled itself, so the row is not deleted on this path.
    was_running = downloader.request_cancel(download_id)
    if was_running:
        return {"canceled": download_id}

    db.delete_download(download_id)
    return {"removed": download_id}
|
|
|
|
|
|
# ---- static UI ------------------------------------------------------------
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve ``static/index.html`` for the root URL."""
    page = STATIC_DIR / "index.html"
    return FileResponse(page)
|
|
|
|
|
|
# Mounted after all route definitions: Starlette matches routes in
# registration order, so the API routes and "/" are tried before this
# catch-all that serves the remaining static assets.
_static_files = StaticFiles(directory=str(STATIC_DIR))
app.mount("/", _static_files, name="static")
|