Files
SparkyUI/model-manager/app/downloader.py
T
TBNilles 6871a6d460 feat(model-manager): add "Delete all" gallery button + clearer 401 errors
- Gallery: DELETE /api/gallery/all removes every image under output/;
  "Delete all" button with in-app confirm and a deleted/failed count.
- Downloads: surface a clear, actionable message when CivitAI/HuggingFace
  returns 401/403 (model requires login/early-access, or the key/token
  lacks access) instead of a bare error, both at resolve time and during
  the download stream.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-07 16:19:48 -04:00

133 lines
5.0 KiB
Python

"""Async streaming downloader with progress tracking and cancellation."""
from __future__ import annotations
import asyncio
import os
import re
from pathlib import Path
from typing import Optional
import httpx
from . import db
from .config import MODELS_DIR, folder_for_type
# download_id -> asyncio.Event set when a cancel is requested.
# run_download() registers an entry when it starts and removes it when it
# finishes; request_cancel() sets the event, and run_download() checks it
# each time a chunk of the response body arrives.
_cancel_flags: dict[int, asyncio.Event] = {}
# Content-Disposition ``filename*=`` parameter (RFC 5987 extended value).
# Only an empty-language ``UTF-8''`` prefix (any case) is consumed by the
# pattern; any other ``charset'language'`` prefix stays in group 1. The
# capture runs up to the next ``;``.
_FILENAME_STAR_RE = re.compile(r"filename\*=(?:UTF-8'')?([^;]+)", re.IGNORECASE)
# Plain ``filename=`` parameter, quoted or bare; group 1 stops at the first
# double quote or semicolon.
_FILENAME_RE = re.compile(r'filename="?([^";]+)"?', re.IGNORECASE)
def sanitize_filename(name: str) -> str:
    """Reduce an untrusted filename to a single, safe path component.

    Surrounding whitespace and double quotes are trimmed, any directory part
    is dropped, NUL characters are removed and remaining ``/`` or ``\\``
    characters become ``_``. A result that is empty, ``.`` or ``..`` is
    replaced by ``"model.bin"``.
    """
    trimmed = name.strip().strip('"')
    cleaned = os.path.basename(trimmed).replace("\x00", "")
    # basename() only splits on the host OS's separator(s); replacing both
    # kinds keeps the result a single component on every platform.
    for separator in ("/", "\\"):
        cleaned = cleaned.replace(separator, "_")
    return "model.bin" if cleaned in {"", ".", ".."} else cleaned
def _filename_from_disposition(value: str) -> Optional[str]:
    """Extract the filename from a ``Content-Disposition`` header value.

    The extended ``filename*`` parameter wins over the plain ``filename``
    parameter. Its value has the RFC 5987 form
    ``charset'language'percent-encoded-name``. It is decoded with the declared
    charset; UTF-8 is used when no charset is given or the name is unknown.

    Args:
        value: Raw header value, e.g.
            ``attachment; filename*=UTF-8''x.bin; filename="x.bin"``.

    Returns:
        The filename, or ``None`` when neither parameter yields a non-empty
        name. The result is NOT sanitized; pass it through
        ``sanitize_filename``/``safe_dest`` before touching the filesystem.
    """
    from urllib.parse import unquote

    m = _FILENAME_STAR_RE.search(value)
    if m:
        encoded = m.group(1).strip().strip('"')
        charset = "utf-8"
        # _FILENAME_STAR_RE consumes an empty-language "UTF-8''" prefix on
        # its own; group 1 then starts later than right after "filename*=".
        # Otherwise a "charset'language'" prefix (other charset, or a
        # language tag such as "UTF-8'en'") may still be present: split it.
        prefix_consumed = m.start(1) - m.start(0) > len("filename*=")
        if not prefix_consumed:
            ext = re.match(r"([A-Za-z0-9!#$%&+^_`{}~-]+)'[A-Za-z0-9-]*'(.*)",
                           encoded, re.DOTALL)
            if ext:
                charset, encoded = ext.group(1), ext.group(2)
        try:
            name = unquote(encoded, encoding=charset, errors="replace")
        except LookupError:  # charset name Python does not know
            name = unquote(encoded, errors="replace")
        if name.strip():
            return name
        # Empty extended value: fall back to the plain parameter below.
    m = _FILENAME_RE.search(value)
    if m:
        return m.group(1)
    return None
def safe_dest(model_type: str, filename: str) -> Path:
    """Resolve where ``filename`` should be saved for ``model_type``.

    The name is passed through ``sanitize_filename`` and joined onto
    ``MODELS_DIR / folder_for_type(model_type)``. Both paths are resolved
    (which follows symlinks) before they are compared.

    Returns:
        The absolute destination path.

    Raises:
        ValueError: if the resolved destination is not strictly inside the
            resolved model folder.
    """
    subfolder = folder_for_type(model_type)
    clean_name = sanitize_filename(filename)
    root = (MODELS_DIR / subfolder).resolve()
    candidate = (root / clean_name).resolve()
    # Compare against "<root><sep>" so neither the folder itself nor a
    # sibling like "<root>-other" is accepted.
    inside = str(candidate).startswith(f"{root}{os.sep}")
    if not inside:
        raise ValueError("Refusing path outside the model folder")
    return candidate
def request_cancel(download_id: int) -> bool:
    """Ask the in-flight download ``download_id`` to stop.

    This only sets the download's cancel flag. run_download() notices it the
    next time a chunk of the body arrives, so the stop is not immediate.

    Returns:
        True if a download with that id was registered (and is now flagged),
        False if none is running.
    """
    flag = _cancel_flags.get(download_id)
    if flag is None:
        return False
    flag.set()
    return True
# Per-phase HTTP timeouts for model downloads. Connect and read are bounded so
# that a dead or stalled server turns into a "failed" download instead of one
# stuck in "downloading". Cancellation is only noticed between received chunks,
# so without a read timeout a stalled stream could not be canceled either.
# Write/pool stay unlimited (GET has no request body).
_DOWNLOAD_TIMEOUT = httpx.Timeout(None, connect=30.0, read=300.0)

# Minimum seconds between progress writes to the DB (~5 writes/sec).
_PROGRESS_INTERVAL = 0.2

# Size of the body chunks read from the stream and written to disk.
_CHUNK_SIZE = 1024 * 256


class _UserCanceled(Exception):
    """Raised inside run_download() once request_cancel() has fired for it."""


def _pick_filename(resp: httpx.Response, filename: Optional[str]) -> str:
    """Choose the (unsanitized) name to save ``resp``'s body under.

    A Content-Disposition filename from the server takes precedence over the
    caller's ``filename``. When neither is available, the last segment of the
    final (post-redirect) URL path is used, or ``"model.bin"`` if it is empty.
    """
    disp = resp.headers.get("content-disposition")
    if disp:
        server_name = _filename_from_disposition(disp)
        if server_name:
            return server_name
    if filename:
        return filename
    # httpx's URL.path excludes the query string and is already
    # percent-decoded, so ".../my%20model.bin?x=1" yields "my model.bin".
    return resp.url.path.rsplit("/", 1)[-1] or "model.bin"


def _failure_message(exc: Exception) -> str:
    """Render a download failure as a short, user-facing error message."""
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        msg = f"HTTP {code}"
        if code in (401, 403):
            msg += " — requires login/early-access or invalid API key"
        return msg
    if isinstance(exc, httpx.TimeoutException):
        return f"Timed out waiting for the server ({type(exc).__name__})"
    # Some exceptions stringify to "" -- fall back to the class name so the
    # UI never shows an empty error.
    return str(exc) or type(exc).__name__


async def run_download(download_id: int, url: str, headers: dict[str, str],
                       model_type: str, filename: Optional[str]) -> None:
    """Stream ``url`` to disk, updating the DB row as it progresses.

    The row's status is set to ``downloading`` and ends as ``completed``,
    ``failed`` (with an error message) or ``canceled``. The body is written
    to ``<dest>.part`` and renamed to ``dest`` only after the whole stream
    has been read, so an interrupted download never sits under the final
    name.

    Args:
        download_id: DB row to update; also the key request_cancel() uses.
        url: Source URL; redirects are followed.
        headers: Request headers sent with every request (presumably auth
            for the model host -- confirm against callers).
        model_type: Selects the destination folder via safe_dest().
        filename: Caller's preferred name. A Content-Disposition filename
            from the server overrides it; the URL's last path segment is
            used when neither is available.

    Raises:
        asyncio.CancelledError: re-raised when this task itself is canceled
            (Task.cancel()), after the row has been marked canceled.
            Other exceptions are caught and recorded on the row as
            ``failed``.
    """
    cancel = asyncio.Event()
    _cancel_flags[download_id] = cancel
    part_path: Optional[Path] = None
    try:
        db.update_download(download_id, status="downloading")
        async with httpx.AsyncClient(follow_redirects=True,
                                     timeout=_DOWNLOAD_TIMEOUT,
                                     headers=headers) as client:
            async with client.stream("GET", url) as resp:
                resp.raise_for_status()
                dest = safe_dest(model_type, _pick_filename(resp, filename))
                # Create the model folder if missing; open() below would
                # otherwise fail with FileNotFoundError.
                dest.parent.mkdir(parents=True, exist_ok=True)
                part_path = dest.with_suffix(dest.suffix + ".part")
                total = int(resp.headers.get("content-length", 0) or 0)
                db.update_download(download_id, filename=dest.name,
                                   dest_path=str(dest), bytes_total=total)
                loop = asyncio.get_running_loop()
                done = 0
                last_report = 0.0
                # NOTE(review): fh.write() and db.update_download() are
                # synchronous calls on the event-loop thread; consider
                # offloading them if they stall other requests.
                with open(part_path, "wb") as fh:
                    async for chunk in resp.aiter_bytes(chunk_size=_CHUNK_SIZE):
                        if cancel.is_set():
                            raise _UserCanceled()
                        fh.write(chunk)
                        done += len(chunk)
                        now = loop.time()
                        if now - last_report > _PROGRESS_INTERVAL:
                            db.update_download(download_id, bytes_done=done)
                            last_report = now
                db.update_download(download_id, bytes_done=done)
        # The whole body has been read: move it to its final name.
        os.replace(part_path, dest)
        part_path = None
        db.update_download(download_id, status="completed")
    except _UserCanceled:
        db.update_download(download_id, status="canceled",
                           error="Canceled by user")
    except asyncio.CancelledError:
        # The task itself was canceled via Task.cancel(). Record it, then
        # re-raise: swallowing CancelledError breaks asyncio cancellation.
        reason = "Canceled by user" if cancel.is_set() else "Canceled"
        db.update_download(download_id, status="canceled", error=reason)
        raise
    except Exception as exc:  # noqa: BLE001 - surface any failure to the UI
        db.update_download(download_id, status="failed",
                           error=_failure_message(exc))
    finally:
        _cancel_flags.pop(download_id, None)
        # No-op after a successful rename (part_path was reset to None).
        _cleanup(part_path)
def _cleanup(part_path: Optional[Path]) -> None:
if part_path is not None:
try:
part_path.unlink(missing_ok=True)
except OSError:
pass