"""Async streaming downloader with progress tracking and cancellation.""" from __future__ import annotations import asyncio import os import re from pathlib import Path from typing import Optional import httpx from . import db from .config import MODELS_DIR, folder_for_type # download_id -> asyncio.Event set when a cancel is requested. _cancel_flags: dict[int, asyncio.Event] = {} _FILENAME_STAR_RE = re.compile(r"filename\*=(?:UTF-8'')?([^;]+)", re.IGNORECASE) _FILENAME_RE = re.compile(r'filename="?([^";]+)"?', re.IGNORECASE) def sanitize_filename(name: str) -> str: """Strip any directory components and unsafe characters from a filename.""" name = os.path.basename(name.strip().strip('"')) name = name.replace("\x00", "") # Disallow path separators / parent refs that survived basename on other OSes. name = name.replace("/", "_").replace("\\", "_") if name in ("", ".", ".."): name = "model.bin" return name def _filename_from_disposition(value: str) -> Optional[str]: from urllib.parse import unquote m = _FILENAME_STAR_RE.search(value) if m: return unquote(m.group(1)) m = _FILENAME_RE.search(value) if m: return m.group(1) return None def safe_dest(model_type: str, filename: str) -> Path: """Build a destination path under MODELS_DIR/, guarding traversal.""" folder = folder_for_type(model_type) filename = sanitize_filename(filename) base = (MODELS_DIR / folder).resolve() dest = (base / filename).resolve() if not str(dest).startswith(str(base) + os.sep): raise ValueError("Refusing path outside the model folder") return dest def request_cancel(download_id: int) -> bool: """Signal an in-flight download to stop. Returns True if it was active.""" ev = _cancel_flags.get(download_id) if ev is not None: ev.set() return True return False async def run_download(download_id: int, url: str, headers: dict[str, str], model_type: str, filename: Optional[str]) -> None: """Stream `url` to disk, updating the DB row as it progresses.""" cancel = asyncio.Event() _cancel_flags[download_id] = cancel part_path: Optional[Path] = None try: db.update_download(download_id, status="downloading") async with httpx.AsyncClient(follow_redirects=True, timeout=None, headers=headers) as client: async with client.stream("GET", url) as resp: resp.raise_for_status() # Prefer a server-provided filename if we don't have a good one. disp = resp.headers.get("content-disposition") if disp: server_name = _filename_from_disposition(disp) if server_name: filename = server_name if not filename: filename = os.path.basename(str(resp.url).split("?")[0]) or "model.bin" dest = safe_dest(model_type, filename) part_path = dest.with_suffix(dest.suffix + ".part") total = int(resp.headers.get("content-length", 0) or 0) db.update_download(download_id, filename=dest.name, dest_path=str(dest), bytes_total=total) done = 0 last_report = 0.0 with open(part_path, "wb") as fh: async for chunk in resp.aiter_bytes(chunk_size=1024 * 256): if cancel.is_set(): raise asyncio.CancelledError() fh.write(chunk) done += len(chunk) # Throttle DB writes to ~5/sec. now = asyncio.get_event_loop().time() if now - last_report > 0.2: db.update_download(download_id, bytes_done=done) last_report = now db.update_download(download_id, bytes_done=done) os.replace(part_path, dest) part_path = None db.update_download(download_id, status="completed") except asyncio.CancelledError: db.update_download(download_id, status="canceled", error="Canceled by user") _cleanup(part_path) except httpx.HTTPStatusError as exc: db.update_download(download_id, status="failed", error=f"HTTP {exc.response.status_code}") _cleanup(part_path) except Exception as exc: # noqa: BLE001 - surface any failure to the UI db.update_download(download_id, status="failed", error=str(exc)) _cleanup(part_path) finally: _cancel_flags.pop(download_id, None) def _cleanup(part_path: Optional[Path]) -> None: if part_path is not None: try: part_path.unlink(missing_ok=True) except OSError: pass