diff --git a/__init__.py b/__init__.py index ff4612e..0484eab 100644 --- a/__init__.py +++ b/__init__.py @@ -114,7 +114,8 @@ async def create_model(request): - downloadUrl: download url. - hash: a JSON string containing the hash value of the downloaded model. """ - task_data = await request.json() + task_data = await request.post() + task_data = dict(task_data) try: task_id = await services.create_model_download_task(task_data, request) return web.json_response({"success": True, "data": {"taskId": task_id}}) @@ -186,7 +187,8 @@ async def update_model(request): index = int(request.match_info.get("index", None)) filename = request.match_info.get("filename", None) - model_data: dict = await request.json() + model_data = await request.post() + model_data = dict(model_data) try: model_path = utils.get_valid_full_path(model_type, index, filename) diff --git a/py/download.py b/py/download.py index 84f61d2..49c3ec6 100644 --- a/py/download.py +++ b/py/download.py @@ -180,8 +180,8 @@ async def create_model_download_task(task_data: dict, request): raise RuntimeError(f"Task {task_id} already exists") try: - preview_url = task_data.pop("preview", None) - utils.save_model_preview_image(task_path, preview_url) + previewFile = task_data.pop("previewFile", None) + utils.save_model_preview_image(task_path, previewFile) set_task_content(task_id, task_data) task_status = TaskStatus( taskId=task_id, @@ -361,9 +361,7 @@ async def download_model_file( ) if response.status_code not in (200, 206): - raise RuntimeError( - f"Failed to download {task_content.fullname}, status code: {response.status_code}" - ) + raise RuntimeError(f"Failed to download {task_content.fullname}, status code: {response.status_code}") # Some models require logging in before they can be downloaded. # If no token is carried, it will be redirected to the login page. @@ -376,9 +374,7 @@ async def download_model_file( # If it cannot be downloaded, a redirect will definitely occur. # Maybe consider getting the redirect url from response.history to make a judgment. # Here we also need to consider how different websites are processed. - raise RuntimeError( - f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first." - ) + raise RuntimeError(f"{task_content.fullname} needs to be logged in to download. Please set the API-Key first.") # When parsing model information from HuggingFace API, # the file size was not found and needs to be obtained from the response header. diff --git a/py/services.py b/py/services.py index b623bc4..3768aaa 100644 --- a/py/services.py +++ b/py/services.py @@ -73,7 +73,10 @@ def update_model(model_path: str, model_data: dict): if "previewFile" in model_data: previewFile = model_data["previewFile"] - utils.save_model_preview_image(model_path, previewFile) + if type(previewFile) is str and previewFile == "undefined": + utils.remove_model_preview_image(model_path) + else: + utils.save_model_preview_image(model_path, previewFile) if "description" in model_data: description = model_data["description"] diff --git a/py/utils.py b/py/utils.py index 275e636..3a5163a 100644 --- a/py/utils.py +++ b/py/utils.py @@ -249,19 +249,45 @@ from PIL import Image from io import BytesIO -def save_model_preview_image(model_path: str, image_url: str): - try: - image_response = requests.get(image_url) - image_response.raise_for_status() +def remove_model_preview_image(model_path: str): + basename = os.path.splitext(model_path)[0] + preview_path = f"{basename}.webp" + if os.path.exists(preview_path): + os.remove(preview_path) - basename = os.path.splitext(model_path)[0] - preview_path = f"{basename}.webp" - image = Image.open(BytesIO(image_response.content)) + +def save_model_preview_image(model_path: str, image_file_or_url: Any): + basename = os.path.splitext(model_path)[0] + preview_path = f"{basename}.webp" + + # Download image file if it is url + if type(image_file_or_url) is str: + image_url = image_file_or_url + + try: + image_response = requests.get(image_url) + image_response.raise_for_status() + + image = Image.open(BytesIO(image_response.content)) + image.save(preview_path, "WEBP") + + except Exception as e: + print_error(f"Failed to download image: {e}") + + else: + # Assert image as file + image_file = image_file_or_url + + if not isinstance(image_file, web.FileField): + raise RuntimeError("Invalid image file") + + content_type: str = image_file.content_type + if not content_type.startswith("image/"): + raise RuntimeError(f"FileTypeError: expected image, got {content_type}") + + image = Image.open(image_file.file) image.save(preview_path, "WEBP") - except Exception as e: - print_error(f"Failed to download image: {e}") - def get_model_all_descriptions(model_path: str): base_dirname = os.path.dirname(model_path) diff --git a/src/components/DialogCreateTask.vue b/src/components/DialogCreateTask.vue index 0c1d09d..2ba5e04 100644 --- a/src/components/DialogCreateTask.vue +++ b/src/components/DialogCreateTask.vue @@ -69,7 +69,8 @@ import { useLoading } from 'hooks/loading' import { request } from 'hooks/request' import { useToast } from 'hooks/toast' import Button from 'primevue/button' -import { VersionModel } from 'types/typings' +import { VersionModel, WithResolved } from 'types/typings' +import { previewUrlToFile } from 'utils/common' import { ref } from 'vue' const { isMobile } = useConfig() @@ -87,12 +88,49 @@ const searchModelsByUrl = async () => { } } -const createDownTask = async (data: VersionModel) => { +const createDownTask = async (data: WithResolved) => { loading.show() + const formData = new FormData() + for (const key in data) { + if (Object.prototype.hasOwnProperty.call(data, key)) { + let value = data[key] + + // set preview file + if (key === 'preview') { + if (value) { + const previewFile = await previewUrlToFile(value).catch(() => { + loading.hide() + toast.add({ + severity: 'error', + summary: 'Error', + detail: 'Failed to download preview', + life: 5000, + }) + throw new Error('Failed to download preview') + }) + formData.append('previewFile', previewFile) + } else { + formData.append('previewFile', value) + } + continue + } + + if (typeof value === 'object') { + value = JSON.stringify(value) + } + + if (typeof value === 'number') { + value = value.toString() + } + + formData.append(key, value) + } + } + await request('/model', { method: 'POST', - body: JSON.stringify(data), + body: formData, }) .then(() => { dialog.close() diff --git a/src/components/DialogModelDetail.vue b/src/components/DialogModelDetail.vue index 38c2c5a..d51cf38 100644 --- a/src/components/DialogModelDetail.vue +++ b/src/components/DialogModelDetail.vue @@ -47,7 +47,7 @@ import ResponseScroll from 'components/ResponseScroll.vue' import { useModelNodeAction, useModels } from 'hooks/model' import { useRequest } from 'hooks/request' import Button from 'primevue/button' -import { BaseModel, Model } from 'types/typings' +import { BaseModel, Model, WithResolved } from 'types/typings' import { computed, ref } from 'vue' interface Props { @@ -72,7 +72,7 @@ const handleCancel = () => { editable.value = false } -const handleSave = async (data: BaseModel) => { +const handleSave = async (data: WithResolved) => { await update(modelContent.value, data) editable.value = false } diff --git a/src/components/ModelContent.vue b/src/components/ModelContent.vue index 77f6155..6c1dd49 100644 --- a/src/components/ModelContent.vue +++ b/src/components/ModelContent.vue @@ -62,7 +62,7 @@ import TabList from 'primevue/tablist' import TabPanel from 'primevue/tabpanel' import TabPanels from 'primevue/tabpanels' import Tabs from 'primevue/tabs' -import { BaseModel } from 'types/typings' +import { BaseModel, WithResolved } from 'types/typings' import { toRaw, watch } from 'vue' interface Props { @@ -73,7 +73,7 @@ const props = defineProps() const editable = defineModel('editable') const emits = defineEmits<{ - submit: [formData: BaseModel] + submit: [formData: WithResolved] reset: [] }>() diff --git a/src/hooks/model.ts b/src/hooks/model.ts index e65ac13..fa5199c 100644 --- a/src/hooks/model.ts +++ b/src/hooks/model.ts @@ -3,10 +3,10 @@ import { useMarkdown } from 'hooks/markdown' import { request } from 'hooks/request' import { defineStore } from 'hooks/store' import { useToast } from 'hooks/toast' -import { cloneDeep } from 'lodash' +import { castArray, cloneDeep } from 'lodash' import { app } from 'scripts/comfyAPI' -import { BaseModel, Model, SelectEvent } from 'types/typings' -import { bytesToSize, formatDate } from 'utils/common' +import { BaseModel, Model, SelectEvent, WithResolved } from 'types/typings' +import { bytesToSize, formatDate, previewUrlToFile } from 'utils/common' import { ModelGrid } from 'utils/legacy' import { genModelKey, resolveModelTypeLoader } from 'utils/model' import { @@ -74,18 +74,30 @@ export const useModels = defineStore('models', (store) => { ) } - const updateModel = async (model: BaseModel, data: BaseModel) => { - const updateData = new Map() + const updateModel = async ( + model: BaseModel, + data: WithResolved, + ) => { + const updateData = new FormData() let oldKey: string | null = null + let needUpdate = false // Check current preview if (model.preview !== data.preview) { - updateData.set('previewFile', data.preview) + const preview = data.preview + if (preview) { + const previewFile = await previewUrlToFile(data.preview as string) + updateData.set('previewFile', previewFile) + } else { + updateData.set('previewFile', 'undefined') + } + needUpdate = true } // Check current description if (model.description !== data.description) { updateData.set('description', data.description) + needUpdate = true } // Check current name and pathIndex @@ -97,16 +109,17 @@ export const useModels = defineStore('models', (store) => { updateData.set('type', data.type) updateData.set('pathIndex', data.pathIndex.toString()) updateData.set('fullname', data.fullname) + needUpdate = true } - if (updateData.size === 0) { + if (!needUpdate) { return } loading.show() await request(`/model/${model.type}/${model.pathIndex}/${model.fullname}`, { method: 'PUT', - body: JSON.stringify(Object.fromEntries(updateData.entries())), + body: updateData, }) .catch((err) => { const error_message = err.message ?? err.error @@ -216,15 +229,15 @@ export const useModelFormData = (getFormData: () => BaseModel) => { } } - type SubmitCallback = (data: BaseModel) => void + type SubmitCallback = (data: WithResolved) => void const submitCallback = ref([]) const registerSubmit = (callback: SubmitCallback) => { submitCallback.value.push(callback) } - const submit = () => { - const data = cloneDeep(toRaw(unref(formData))) + const submit = (): WithResolved => { + const data: any = cloneDeep(toRaw(unref(formData))) for (const callback of submitCallback.value) { callback(data) } @@ -394,9 +407,7 @@ export const useModelPreviewEditor = (formInstance: ModelFormInstance) => { * Default images */ const defaultContent = computed(() => { - return Array.isArray(model.value.preview) - ? model.value.preview - : [model.value.preview] + return model.value.preview ? castArray(model.value.preview) : [] }) const defaultContentPage = ref(0) @@ -435,7 +446,7 @@ export const useModelPreviewEditor = (formInstance: ModelFormInstance) => { content = localContent.value break default: - content = noPreviewContent.value + content = undefined break } @@ -451,7 +462,7 @@ export const useModelPreviewEditor = (formInstance: ModelFormInstance) => { }) registerSubmit((data) => { - data.preview = preview.value ?? noPreviewContent.value + data.preview = preview.value }) }) diff --git a/src/types/typings.d.ts b/src/types/typings.d.ts index 3c81ec3..9312331 100644 --- a/src/types/typings.d.ts +++ b/src/types/typings.d.ts @@ -26,6 +26,10 @@ export interface VersionModel extends BaseModel { hashes?: Record } +export type WithResolved = Omit & { + preview: string | undefined +} + export type PassThrough = T | object | undefined export interface SelectOptions { diff --git a/src/utils/common.ts b/src/utils/common.ts index fd1c0f7..5d6a259 100644 --- a/src/utils/common.ts +++ b/src/utils/common.ts @@ -26,3 +26,14 @@ export const bytesToSize = ( export const formatDate = (date: number | string | Date) => { return dayjs(date).format('YYYY-MM-DD HH:mm:ss') } + +export const previewUrlToFile = async (url: string) => { + return fetch(url) + .then((res) => res.blob()) + .then((blob) => { + const type = blob.type + const extension = type.split('/')[1] + const file = new File([blob], `preview.${extension}`, { type }) + return file + }) +}