import { useLoading } from 'hooks/loading'
import { MarkdownTool, useMarkdown } from 'hooks/markdown'
import { request } from 'hooks/request'
import { defineStore } from 'hooks/store'
import { useToast } from 'hooks/toast'
import { api } from 'scripts/comfyAPI'
import { bytesToSize } from 'utils/common'
import { onBeforeMount, onMounted, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'

/**
 * Store holding the download-task list.
 *
 * The list is loaded from `/download/task` on mount and whenever `api`
 * emits `reconnected`. In between it is patched by the
 * `fetch_/create_/update_/delete_/complete_download_task` events.
 */
export const useDownload = defineStore('download', (store) => {
  const { toast, confirm } = useToast()
  const { t } = useI18n()

  const taskList = ref<DownloadTask[]>([])

  /** Send a `pause` / `resume` status change for one task to the backend. */
  const updateTaskStatus = (
    taskId: DownloadTaskOptions['taskId'],
    status: 'pause' | 'resume',
  ) => {
    request(`/download/${taskId}`, {
      method: 'PUT',
      body: JSON.stringify({ status }),
    })
  }

  /**
   * Convert a raw task payload into a `DownloadTask`. The byte counters are
   * replaced by formatted progress/speed strings, and pause/resume/delete
   * actions are bound to the task id.
   */
  const createTaskItem = (item: DownloadTaskOptions) => {
    const { downloadedSize, totalSize, bps, ...rest } = item

    const task: DownloadTask = {
      ...rest,
      preview: `/model-manager/preview/download/${item.preview}`,
      downloadProgress: `${bytesToSize(downloadedSize)} / ${bytesToSize(totalSize)}`,
      downloadSpeed: `${bytesToSize(bps)}/s`,
      pauseTask: () => updateTaskStatus(item.taskId, 'pause'),
      resumeTask: () => updateTaskStatus(item.taskId, 'resume'),
      deleteTask: () => {
        confirm.require({
          message: t('deleteAsk', [t('downloadTask').toLowerCase()]),
          header: 'Danger',
          icon: 'pi pi-info-circle',
          rejectProps: {
            label: t('cancel'),
            severity: 'secondary',
            outlined: true,
          },
          acceptProps: {
            label: t('delete'),
            severity: 'danger',
          },
          accept: () => {
            request(`/download/${item.taskId}`, {
              method: 'DELETE',
            })
          },
          reject: () => {},
        })
      },
    }
    return task
  }

  /**
   * Reload the whole task list from the backend.
   *
   * @returns the new list, or `undefined` after a failure (which is toasted)
   */
  const refresh = async () => {
    return request('/download/task')
      .then((resData: DownloadTaskOptions[]) => {
        taskList.value = resData.map((item) => createTaskItem(item))
        return taskList.value
      })
      .catch((err) => {
        toast.add({
          severity: 'error',
          summary: 'Error',
          detail: err?.message ?? 'Failed to refresh download task list',
          life: 15000,
        })
      })
  }

  onBeforeMount(() => {
    api.addEventListener('reconnected', () => {
      refresh()
    })

    api.addEventListener('fetch_download_task_list', (event) => {
      const data = event.detail as DownloadTaskOptions[]
      taskList.value = data.map((item) => createTaskItem(item))
    })

    api.addEventListener('create_download_task', (event) => {
      const item = event.detail as DownloadTaskOptions
      taskList.value.unshift(createTaskItem(item))
    })

    api.addEventListener('update_download_task', (event) => {
      const item = event.detail as DownloadTaskOptions
      for (const task of taskList.value) {
        if (task.taskId !== item.taskId) {
          continue
        }
        if (item.error) {
          toast.add({
            severity: 'error',
            summary: 'Error',
            detail: item.error,
            life: 15000,
          })
          // Clear the error so it is toasted only once and the rebuilt task
          // (which spreads `item`) does not carry it.
          item.error = undefined
        }
        Object.assign(task, createTaskItem(item))
      }
    })

    api.addEventListener('delete_download_task', (event) => {
      const taskId = event.detail as string
      taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
    })

    api.addEventListener('complete_download_task', (event) => {
      const taskId = event.detail as string
      const task = taskList.value.find((item) => item.taskId === taskId)
      taskList.value = taskList.value.filter((item) => item.taskId !== taskId)
      toast.add({
        severity: 'success',
        summary: 'Success',
        // Fall back to the id so the toast never reads "undefined ...".
        detail: `${task?.fullname ?? taskId} Download completed`,
        life: 2000,
      })
      store.models.refresh()
    })
  })

  onMounted(() => {
    refresh()
  })

  return { data: taskList, refresh }
})

declare module 'hooks/store' {
  interface StoreProvider {
    download: ReturnType<typeof useDownload>
  }
}

/**
 * GET `url` and parse the JSON body. Rejects with a readable error for
 * non-2xx responses (missing, private or gated models) instead of letting a
 * later property access on the error payload fail with a TypeError.
 */
const fetchJson = async (url: string): Promise<any> => {
  const response = await fetch(url)
  if (!response.ok) {
    throw new Error(
      `Request to ${url} failed: ${response.status} ${response.statusText}`,
    )
  }
  return response.json()
}

/**
 * Split a file name at its LAST dot into base name and extension (with the
 * leading dot). A name without a dot gets an empty extension.
 */
const splitFilename = (fullname: string) => {
  const dotIndex = fullname.lastIndexOf('.')
  if (dotIndex < 0) {
    return { basename: fullname, extension: '' }
  }
  return {
    basename: fullname.slice(0, dotIndex),
    extension: fullname.slice(dotIndex),
  }
}

/** Resolves a model page URL into the downloadable model files it lists. */
abstract class ModelSearch {
  constructor(readonly md: MarkdownTool) {}

  abstract search(searchUrl: string): Promise<VersionModel[]>
}

/** Search backed by the public Civitai REST API (`/api/v1/models/:id`). */
class Civitai extends ModelSearch {
  async search(searchUrl: string): Promise<VersionModel[]> {
    const { pathname, searchParams } = new URL(searchUrl)
    const [, modelId] = pathname.match(/^\/models\/(\d*)/) ?? []
    const versionId = searchParams.get('modelVersionId')

    if (!modelId) {
      return []
    }

    const resData = await fetchJson(
      `https://civitai.com/api/v1/models/${modelId}`,
    )

    // `modelVersionId` narrows the result to one version. The query value is
    // a string and the API id is presumably numeric, so compare as strings.
    const modelVersions: any[] = (resData.modelVersions ?? []).filter(
      (version: any) => !versionId || String(version.id) === versionId,
    )

    const models: VersionModel[] = []
    for (const version of modelVersions) {
      const modelFiles: any[] = (version.files ?? []).filter(
        (file: any) => file.type === 'Model',
      )

      for (const file of modelFiles) {
        const fullname: string = file.name
        const { basename, extension } = splitFilename(fullname)

        models.push({
          id: file.id,
          shortname: version.name ?? basename,
          fullname,
          basename,
          extension,
          preview: (version.images ?? []).map((image: any) => image.url),
          sizeBytes: file.sizeKB * 1024,
          type: this.resolveType(resData.type),
          pathIndex: 0,
          description: this.buildDescription(resData, version, file, modelId),
          metadata: file.metadata,
          downloadPlatform: 'civitai',
          downloadUrl: file.downloadUrl,
          hashes: file.hashes,
        })
      }
    }
    return models
  }

  /**
   * Build the markdown description: a `---` delimited front matter block with
   * page/author/hash/metadata info, followed by trigger words and the
   * version and model descriptions.
   */
  private buildDescription(
    resData: any,
    version: any,
    file: any,
    modelId: string,
  ): string {
    const frontMatter = [
      `website: Civitai`,
      `modelPage: https://civitai.com/models/${modelId}?modelVersionId=${version.id}`,
      `author: ${resData.creator?.username}`,
      version.baseModel && `baseModel: ${version.baseModel}`,
      file.hashes && `hashes:`,
      ...Object.entries(file.hashes ?? {}).map(
        ([key, value]) => ` ${key}: ${value}`,
      ),
      file.metadata && `metadata:`,
      ...Object.entries(file.metadata ?? {}).map(
        ([key, value]) => ` ${key}: ${value}`,
      ),
    ].filter(Boolean)

    // An empty `trainedWords` array must also fall back to the placeholder.
    const trainedWords: string[] = (version.trainedWords ?? []).filter(Boolean)
    const triggerWords =
      trainedWords.length > 0 ? trainedWords.join(', ') : 'No trigger words'

    return [
      '---',
      ...frontMatter,
      '---',
      '',
      '# Trigger Words',
      `\n${triggerWords}\n`,
      '# About this version',
      this.resolveDescription(
        version.description,
        '\nNo description about this version\n',
      ),
      `# ${resData.name}`,
      this.resolveDescription(
        resData.description,
        'No description about this model',
      ),
    ].join('\n')
  }

  /**
   * Map a Civitai model type onto a model folder name. Types not listed in
   * the table are lower-cased and pluralised; a missing type yields
   * `'unknown'`.
   */
  private resolveType(type: string | undefined) {
    const mapLegacy: Record<string, string> = {
      TextualInversion: 'embeddings',
      LoCon: 'loras',
      DoRA: 'loras',
      Controlnet: 'controlnet',
      Upscaler: 'upscale_models',
      VAE: 'vae',
    }
    if (!type) {
      return 'unknown'
    }
    return mapLegacy[type] ?? `${type.toLowerCase()}s`
  }

  /** Render `content` through the markdown tool; use `defaultContent` if empty. */
  private resolveDescription(
    content: string | undefined,
    defaultContent: string,
  ) {
    const mdContent = this.md.parse(content ?? '').trim()
    return mdContent || defaultContent
  }
}

/**
 * Search backed by the Hugging Face model API. The optional
 * `/tree/<rev>/<dir>` or `/blob/<rev>/<file>` suffix of the page URL limits
 * the listed files to that path.
 */
class Huggingface extends ModelSearch {
  async search(searchUrl: string): Promise<VersionModel[]> {
    const { pathname } = new URL(searchUrl)
    const [, space, name, ...restPaths] = pathname.split('/')

    if (!space || !name) {
      return []
    }

    const modelId = `${space}/${name}`
    const restPathname = restPaths.join('/')

    const resData = await fetchJson(
      `https://huggingface.co/api/models/${modelId}`,
    )

    const siblingFiles: string[] = (resData.siblings ?? []).map(
      (item: any) => item.rfilename,
    )

    const modelFiles = this.filterTreeFiles(
      this.filterModelFiles(siblingFiles),
      restPathname,
    )
    const images = this.filterTreeFiles(
      this.filterImageFiles(siblingFiles),
      restPathname,
    ).map(
      (filename) => `https://huggingface.co/${modelId}/resolve/main/${filename}`,
    )

    return modelFiles.map((filename): VersionModel => {
      const fullname = filename.split('/').pop() ?? filename
      const { basename, extension } = splitFilename(fullname)

      return {
        id: filename,
        shortname: filename,
        fullname,
        basename,
        extension,
        preview: images,
        sizeBytes: 0,
        type: 'unknown',
        pathIndex: 0,
        description: [
          '---',
          `website: HuggingFace`,
          `modelPage: https://huggingface.co/${modelId}`,
          `author: ${resData.author}`,
          '---',
          '',
          '# Trigger Words',
          '\nNo trigger words\n',
          '# About this version',
          '\nNo description about this version\n',
          `# ${resData.modelId}`,
          '\nNo description about this model\n',
        ].join('\n'),
        metadata: {},
        downloadPlatform: 'huggingface',
        downloadUrl: `https://huggingface.co/${modelId}/resolve/main/${filename}?download=true`,
      }
    })
  }

  /**
   * Keep only files under the path named by a `tree/<rev>/<path>` or
   * `blob/<rev>/<path>` URL suffix. Matching is per path segment, so `sub`
   * does not match `sub2/...`. Any other suffix keeps every file.
   */
  private filterTreeFiles(files: string[], pathname: string) {
    const [target, , ...paths] = pathname.split('/')
    if (target !== 'tree' && target !== 'blob') {
      return files
    }

    // Drop empty segments so a trailing slash does not break the match.
    const pathPrefix = paths.filter(Boolean).join('/')
    if (!pathPrefix) {
      return files
    }

    return files.filter(
      (file) => file === pathPrefix || file.startsWith(`${pathPrefix}/`),
    )
  }

  /** Files whose extension marks them as model weights. */
  private filterModelFiles(files: string[]) {
    return this.filterByExtension(files, [
      '.bin',
      '.ckpt',
      '.gguf',
      '.onnx',
      '.pt',
      '.pth',
      '.safetensors',
    ])
  }

  /** Files whose extension marks them as preview images. */
  private filterImageFiles(files: string[]) {
    return this.filterByExtension(files, [
      '.png',
      '.webp',
      '.jpeg',
      '.jpg',
      '.jfif',
      '.gif',
      '.apng',
    ])
  }

  /** Keep files whose last `.`-separated part is in `extensions`. */
  private filterByExtension(files: string[], extensions: readonly string[]) {
    return files.filter((file) => {
      const ext = file.split('.').pop()
      return ext ? extensions.includes(`.${ext}`) : false
    })
  }
}

/** Fallback for hosts with no dedicated search; always rejects. */
class UnknownWebsite extends ModelSearch {
  async search(searchUrl: string): Promise<VersionModel[]> {
    return Promise.reject(
      new Error(
        'Unknown Website, please input a URL from huggingface.co or civitai.com.',
      ),
    )
  }
}

/** Hostname of `url`, or `undefined` when `url` is not a valid absolute URL. */
const parseHostname = (url: string): string | undefined => {
  try {
    return new URL(url).hostname
  } catch {
    return undefined
  }
}

/** Choose the search implementation for a host; a leading `www.` is ignored. */
const createModelSearch = (hostname: string, md: MarkdownTool): ModelSearch => {
  switch (hostname.replace(/^www\./, '')) {
    case 'civitai.com':
      return new Civitai(md)
    case 'huggingface.co':
      return new Huggingface(md)
    default:
      return new UnknownWebsite(md)
  }
}

/**
 * Search a model page URL and expose the results as select options.
 *
 * `current` holds the selected option value (a model id) and
 * `currentModel` the matching model. Both are reset to the first result
 * after every successful search.
 */
export const useModelSearch = () => {
  const loading = useLoading()
  const md = useMarkdown()
  const { toast } = useToast()
  const data = ref<(SelectOptions & { item: VersionModel })[]>([])
  const current = ref<VersionModel['id']>()
  const currentModel = ref<VersionModel>()

  /**
   * Run a search for `url`. Failures (invalid URL, unsupported host, network
   * or API errors) are reported as toasts and resolve to `[]`.
   */
  const handleSearchByUrl = async (url: string) => {
    if (!url) {
      return []
    }

    // `new URL` throws on malformed input; report it like any other search
    // failure instead of rejecting the caller's promise.
    const hostname = parseHostname(url)
    if (hostname === undefined) {
      toast.add({
        severity: 'error',
        summary: 'Error',
        detail: `Invalid URL: ${url}`,
        life: 15000,
      })
      return []
    }

    const instance = createModelSearch(hostname, md)

    loading.show()
    return instance
      .search(url)
      .then((resData) => {
        data.value = resData.map((item) => ({
          label: item.shortname,
          value: item.id,
          item,
          command() {
            current.value = item.id
          },
        }))

        current.value = data.value[0]?.value
        currentModel.value = data.value[0]?.item

        if (resData.length === 0) {
          toast.add({
            severity: 'warn',
            summary: 'No Model Found',
            detail: `No model found for ${url}`,
            life: 3000,
          })
        }

        return resData
      })
      .catch((err) => {
        toast.add({
          severity: 'error',
          summary: 'Error',
          detail: err?.message,
          life: 15000,
        })
        return []
      })
      .finally(() => loading.hide())
  }

  watch(current, () => {
    currentModel.value = data.value.find(
      (option) => option.value === current.value,
    )?.item
  })

  return { data, current, currentModel, search: handleSearchByUrl }
}