6871a6d460
- Gallery: DELETE /api/gallery/all removes every image under output/; "Delete all" button with in-app confirm and a deleted/failed count. - Downloads: surface a clear, actionable message when CivitAI/HuggingFace returns 401/403 (model requires login/early-access, or the key/token lacks access) instead of a bare error, both at resolve time and during the download stream. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
133 lines
5.0 KiB
Python
133 lines
5.0 KiB
Python
"""Async streaming downloader with progress tracking and cancellation."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
|
|
from . import db
|
|
from .config import MODELS_DIR, folder_for_type
|
|
|
|
# download_id -> asyncio.Event set when a cancel is requested.
|
|
# Entries are inserted by run_download when a transfer starts and popped in its
# finally block, so a missing key means that download is not currently running.
_cancel_flags: dict[int, asyncio.Event] = {}
|
|
|
|
# Extended `filename*=` parameter of a Content-Disposition header. An optional
# leading UTF-8'' prefix is skipped and the capture runs up to the next ';'.
# The captured text is percent-decoded by _filename_from_disposition.
_FILENAME_STAR_RE = re.compile(r"filename\*=(?:UTF-8'')?([^;]+)", re.IGNORECASE)
|
|
# Plain `filename=` parameter, with or without surrounding double quotes; the
# capture stops at the first '"' or ';'.
_FILENAME_RE = re.compile(r'filename="?([^";]+)"?', re.IGNORECASE)
|
|
|
|
|
|
def sanitize_filename(name: str) -> str:
    """Reduce *name* to a bare file name that is safe to join onto a folder.

    Surrounding whitespace and double quotes are trimmed, any directory part is
    dropped, NUL bytes are removed, and leftover slashes or backslashes become
    underscores. If nothing usable remains (empty, ``.`` or ``..``) the
    fallback ``model.bin`` is returned.
    """
    trimmed = name.strip().strip('"')
    leaf = os.path.basename(trimmed)
    # basename only understands the host OS separator, so neutralise both
    # separator styles explicitly after dropping NUL bytes.
    for unwanted, substitute in (("\x00", ""), ("/", "_"), ("\\", "_")):
        leaf = leaf.replace(unwanted, substitute)
    return "model.bin" if leaf in ("", ".", "..") else leaf
|
|
|
|
|
|
def _filename_from_disposition(value: str) -> Optional[str]:
    """Extract a file name from a Content-Disposition header value.

    The extended ``filename*=`` form is tried first and its match is
    percent-decoded. Otherwise the plain ``filename=`` form is returned exactly
    as matched. Returns ``None`` when the header carries neither form.
    """
    from urllib.parse import unquote

    extended = _FILENAME_STAR_RE.search(value)
    if extended is not None:
        return unquote(extended.group(1))

    plain = _FILENAME_RE.search(value)
    return plain.group(1) if plain is not None else None
|
|
|
|
|
|
def safe_dest(model_type: str, filename: str) -> Path:
    """Return the resolved target path for *filename* inside its model folder.

    The folder is ``MODELS_DIR / folder_for_type(model_type)`` and the file
    name is passed through ``sanitize_filename`` before being joined onto it.

    Raises:
        ValueError: if the resolved path does not sit below that folder.
    """
    root = (MODELS_DIR / folder_for_type(model_type)).resolve()
    candidate = (root / sanitize_filename(filename)).resolve()
    # Compare against "root + separator" so a sibling directory whose name
    # merely starts with the root's name cannot pass as being inside it.
    inside_prefix = str(root) + os.sep
    if str(candidate).startswith(inside_prefix):
        return candidate
    raise ValueError("Refusing path outside the model folder")
|
|
|
|
|
|
def request_cancel(download_id: int) -> bool:
    """Ask a running download to stop.

    Returns ``True`` when a cancel flag was registered for *download_id* (the
    flag is set as a side effect), ``False`` when no such download is active.
    """
    flag = _cancel_flags.get(download_id)
    if flag is None:
        return False
    flag.set()
    return True
|
|
|
|
|
|
async def run_download(download_id: int, url: str, headers: dict[str, str],
                       model_type: str, filename: Optional[str]) -> None:
    """Stream `url` to disk, updating the DB row as it progresses.

    The body is written to ``<dest>.part`` and moved onto the final path only
    after the whole stream has been received, so an interrupted transfer never
    leaves a truncated file under the final name. The outcome is recorded on
    the download row through ``db.update_download``: ``downloading`` first,
    then ``completed``, ``canceled`` or ``failed`` (with an ``error`` text).

    Args:
        download_id: Row to update; also the key ``request_cancel`` uses to
            find this transfer's cancel flag.
        url: Address to fetch with a GET request (redirects are followed).
        headers: Headers attached to the HTTP client.
        model_type: Passed to ``safe_dest`` to choose the destination folder.
        filename: Fallback name. A name from the response's
            Content-Disposition header takes precedence; if neither is
            available the last URL path segment (or ``model.bin``) is used.

    Raises:
        asyncio.CancelledError: Only when the surrounding task itself is
            cancelled. A cancel requested through ``request_cancel`` is
            recorded on the row and does not propagate.
    """
    cancel = asyncio.Event()
    _cancel_flags[download_id] = cancel
    part_path: Optional[Path] = None
    try:
        db.update_download(download_id, status="downloading")
        async with httpx.AsyncClient(follow_redirects=True, timeout=None,
                                     headers=headers) as client:
            async with client.stream("GET", url) as resp:
                resp.raise_for_status()

                # A server-provided name always wins over the caller's hint;
                # the hint is only used when the header is absent or unparsable.
                disp = resp.headers.get("content-disposition")
                if disp:
                    server_name = _filename_from_disposition(disp)
                    if server_name:
                        filename = server_name
                if not filename:
                    filename = os.path.basename(str(resp.url).split("?")[0]) or "model.bin"

                dest = safe_dest(model_type, filename)
                # The per-type folder may not exist yet on a fresh install.
                dest.parent.mkdir(parents=True, exist_ok=True)
                part_path = dest.with_suffix(dest.suffix + ".part")
                total = int(resp.headers.get("content-length", 0) or 0)
                db.update_download(download_id, filename=dest.name,
                                   dest_path=str(dest), bytes_total=total)

                done = 0
                last_report = 0.0
                # Looked up once; we are inside a coroutine so a loop is running.
                loop = asyncio.get_running_loop()
                with open(part_path, "wb") as fh:
                    async for chunk in resp.aiter_bytes(chunk_size=1024 * 256):
                        if cancel.is_set():
                            raise asyncio.CancelledError()
                        fh.write(chunk)
                        done += len(chunk)
                        # Throttle DB writes to ~5/sec.
                        now = loop.time()
                        if now - last_report > 0.2:
                            db.update_download(download_id, bytes_done=done)
                            last_report = now

                # Final byte count, then publish the file under its real name.
                db.update_download(download_id, bytes_done=done)
                os.replace(part_path, dest)
                # Nothing left to clean up once the rename has succeeded.
                part_path = None
                db.update_download(download_id, status="completed")
    except asyncio.CancelledError:
        db.update_download(download_id, status="canceled", error="Canceled by user")
        _cleanup(part_path)
        if not cancel.is_set():
            # The flag was not set, so this is the task itself being cancelled
            # (e.g. shutdown); asyncio expects that to propagate after clean-up.
            raise
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        msg = f"HTTP {code}"
        if code in (401, 403):
            msg += " — requires login/early-access or invalid API key"
        db.update_download(download_id, status="failed", error=msg)
        _cleanup(part_path)
    except Exception as exc:  # noqa: BLE001 - surface any failure to the UI
        # Some exceptions stringify to "", which would show a blank error.
        db.update_download(download_id, status="failed",
                           error=str(exc) or type(exc).__name__)
        _cleanup(part_path)
    finally:
        _cancel_flags.pop(download_id, None)
|
|
|
|
|
|
def _cleanup(part_path: Optional[Path]) -> None:
    """Best-effort removal of a leftover partial download file.

    Does nothing when *part_path* is ``None``. A file that is already gone is
    not an error, and any other ``OSError`` raised by the removal is ignored.
    """
    if part_path is None:
        return
    try:
        part_path.unlink(missing_ok=True)
    except OSError:
        # Removal is only tidy-up; failing to delete must not mask the
        # outcome that triggered it.
        return
|