diff --git a/__init__.py b/__init__.py index 3145e07..0602c69 100644 --- a/__init__.py +++ b/__init__.py @@ -191,6 +191,9 @@ def ui_rules(): Rule("model-add-drag-strict-on-field", False, bool), Rule("model-add-offset", 25, int), Rule("download-save-description-as-text-file", True, bool), + Rule("tag-generator-sampler-method", "Frequency", str), + Rule("tag-generator-count", 10, int), + Rule("tag-generator-threshold", 2, int), ] diff --git a/web/model-manager.css b/web/model-manager.css index dd394d3..e10a567 100644 --- a/web/model-manager.css +++ b/web/model-manager.css @@ -28,6 +28,21 @@ gap: 16px; } +.model-manager .no-highlight { + user-select: none; + -moz-user-select: none; + -webkit-text-select: none; + -webkit-user-select: none; +} + +.model-manager label:has(> *){ + pointer-events: none; +} + +.model-manager label > * { + pointer-events: auto; +} + /* sidebar buttons */ .model-manager .sidebar-buttons { overflow: hidden; @@ -605,7 +620,9 @@ /* model manager settings */ .model-manager .model-manager-settings > div, -.model-manager .model-manager-settings > label { +.model-manager .model-manager-settings > label, +.model-manager .tag-generator-settings > label, +.model-manager .tag-generator-settings > div { display: flex; flex-direction: row; align-items: center; @@ -613,20 +630,13 @@ margin: 16px 0; } -.model-manager label { - pointer-events: none; -} - -.model-manager label > * { - pointer-events: auto; -} - .model-manager .model-manager-settings button { height: 40px; width: 120px; } -.model-manager .model-manager-settings input[type="number"] { +.model-manager .model-manager-settings input[type="number"], +.model-manager .tag-generator-settings input[type="number"]{ width: 50px; } diff --git a/web/model-manager.js b/web/model-manager.js index 5b56098..dded4e0 100644 --- a/web/model-manager.js +++ b/web/model-manager.js @@ -227,7 +227,7 @@ function $checkbox(x = { $: (el) => {}, textContent: "", checked: false }) { const text = x.textContent; const input = $el("input", { type: "checkbox", - name: text ?? "checkbox", + name: text ?? "checkbox", checked: x.checked ?? false, }); const label = $el("label", [ @@ -254,7 +254,7 @@ function $select(x = { $: (el) => {}, textContent: "", options: [""] }) { })); const label = $el("label", [ text === "" || text === undefined || text === null ? "" : " " + text, - select, + select, ]); if (x.$ !== undefined){ x.$(select); @@ -284,7 +284,7 @@ function $radioGroup(attr) { checked: index === 0, $: (el) => (inputRef.value = el), }), - $el("label", [item.label ?? item.value]), + $el("label.no-highlight", item.label ?? item.value), ] ); }); @@ -379,7 +379,6 @@ function GenerateDynamicTabTextCallback(element, tabButtons, minWidth) { } /** - * * @param {[String, int][]} map * @returns {String} */ @@ -398,6 +397,20 @@ function TagCountMapToParagraph(map) { return text; } +/** + * @param {String} p + * @returns {[String, int][]} + */ +function ParseTagParagraph(p) { + return p.split(",").map(x => { + const text = x.endsWith(", ") ? x.substring(0, x.length - 2) : x; + const i = text.lastIndexOf("("); + const tag = text.substring(0, i).trim(); + const frequency = parseInt(text.substring(i + 1, text.length - 1)); + return [tag, frequency]; + }); +} + class ImageSelect { /** @constant {string} */ #PREVIEW_DEFAULT = "Default"; /** @constant {string} */ #PREVIEW_UPLOAD = "Upload"; @@ -1883,11 +1896,16 @@ class ModelInfo { /** @type {string} */ #savedNotesValue = null; + /** @type {[HTMLElement][]} */ + #settingsElements = null; + /** * @param {ModelData} modelData * @param {() => Promise} updateModels + * @param {any} settingsElements */ - constructor(modelData, updateModels) { + constructor(modelData, updateModels, settingsElements) { + this.#settingsElements = settingsElements; const moveDestinationInput = $el("input.search-text-area", { name: "move directory", autocomplete: "off", @@ -2365,20 +2383,91 @@ class ModelInfo { /** @type {HTMLDivElement} */ const tagsElement = this.elements.tabContents[2]; // TODO: remove magic value const isTags = Array.isArray(tags) && tags.length > 0; + const tagsParagraph = $el("div", (() => { + const elements = []; + if (isTags) { + let text = TagCountMapToParagraph(tags); + const div = $el("div"); + div.innerHTML = text; + elements.push(div); + } + return elements; + })(), + ); + const tagGeneratorRandomizedOutput = $el("textarea.comfy-multiline-input", { + name: "random tag generator output", + rows: 4, + }); + const TAG_GENERATOR_SAMPLER_NAME = "model manager tag generator sampler"; + const tagGenerationCount = $el("input", { + type: "number", + name: "tag generator count", + step: 1, + min: 1, + value: this.#settingsElements["tag-generator-count"].value, + }); + const tagGenerationThreshold = $el("input", { + type: "number", + name: "tag generator threshold", + step: 1, + min: 1, + value: this.#settingsElements["tag-generator-threshold"].value, + }); + const selectedSamplerOption = this.#settingsElements["tag-generator-sampler-method"].value; + const samplerOptions = ["Frequency", "Uniform"]; + const samplerRadioGroup = $radioGroup({ + name: TAG_GENERATOR_SAMPLER_NAME, + onchange: (value) => {}, + options: samplerOptions.map(option => { return { value: option }; }), + }); + const samplerOptionInputs = samplerRadioGroup.getElementsByTagName("input"); + for (let i = 0; i < samplerOptionInputs.length; i++) { + const samplerOptionInput = samplerOptionInputs[i]; + if (samplerOptionInput.value === selectedSamplerOption) { + samplerOptionInput.click(); + break; + } + } tagsElement.innerHTML = ""; tagsElement.append.apply(tagsElement, [ $el("h1", { style: { "margin-top": "0px", "margin-bottom": "0px" } }, ["Tags"]), - $el("div", (() => { - const elements = []; - if (isTags) { - let text = TagCountMapToParagraph(tags); - const div = $el("div"); - div.innerHTML = text; - elements.push(div); - } - return elements; - })(), - ), + $el("h2", ["Random Tag Generator"]), + $el("div", [ + $el("details.tag-generator-settings", { + style: { margin: "10px 0", display: "none" }, + open: false, + }, [ + $el("summary", ["Settings"]), + $el("div", [ + "Sampling Method", + samplerRadioGroup, + ]), + $el("label", [ + "Count", + tagGenerationCount, + ]), + $el("label", [ + "Threshold", + tagGenerationThreshold, + ]), + ]), + tagGeneratorRandomizedOutput, + $el("button", { + textContent: "Randomize", + style: { width: "100%" }, + onclick: (e) => { + const samplerName = document.querySelector(`input[name="${TAG_GENERATOR_SAMPLER_NAME}"]:checked`).value; + const sampler = samplerName === "Frequency" ? ModelInfo.ProbabilisticTagSampling : ModelInfo.UniformTagSampling; + const sampleCount = tagGenerationCount.value; + const frequencyThreshold = tagGenerationThreshold.value; + const tags = ParseTagParagraph(tagsParagraph.innerText); + const sampledTags = sampler(tags, sampleCount, frequencyThreshold); + tagGeneratorRandomizedOutput.value = sampledTags.join(", "); + }, + }), + ]), + $el("h2", ["Training Tags"]), + tagsParagraph, ]); const tagButton = this.elements.tabButtons[2]; // TODO: remove magic value tagButton.style.display = isTags ? "" : "none"; @@ -2414,6 +2503,48 @@ class ModelInfo { })() ); } + + static UniformTagSampling(tagsAndCounts, sampleCount, frequencyThreshold = 0) { + const data = tagsAndCounts.filter(x => x[1] >= frequencyThreshold); + let count = data.length; + const samples = []; + for (let i = 0; i < sampleCount; i++) { + if (count === 0) { break; } + const index = Math.floor(Math.random() * count); + const pair = data.splice(index, 1)[0]; + samples.push(pair); + count -= 1; + } + const sortedSamples = samples.sort((x1, x2) => { return parseInt(x2[1]) - parseInt(x1[1]) }); + return sortedSamples.map(x => x[0]); + } + + static ProbabilisticTagSampling(tagsAndCounts, sampleCount, frequencyThreshold = 0) { + const data = tagsAndCounts.filter(x => x[1] >= frequencyThreshold); + let tagFrequenciesSum = data.reduce((accumulator, x) => accumulator + x[1], 0); + let count = data.length; + const samples = []; + for (let i = 0; i < sampleCount; i++) { + if (count === 0) { break; } + const index = (() => { + let frequencyIndex = Math.floor(Math.random() * tagFrequenciesSum); + return data.findIndex(x => { + const frequency = x[1]; + if (frequency > frequencyIndex) { + return true; + } + frequencyIndex = frequencyIndex - frequency; + return false; + }); + })(); + const pair = data.splice(index, 1)[0]; + samples.push(pair); + tagFrequenciesSum -= pair[1]; + count -= 1; + } + const sortedSamples = samples.sort((x1, x2) => { return parseInt(x2[1]) - parseInt(x1[1]) }); + return sortedSamples.map(x => x[0]); + } } class Civitai { @@ -3284,6 +3415,10 @@ class SettingsView { /** @type {HTMLInputElement} */ "model-add-offset": null, /** @type {HTMLInputElement} */ "download-save-description-as-text-file": null, + + /** @type {HTMLInputElement} */ "tag-generator-sampler-method": null, + /** @type {HTMLInputElement} */ "tag-generator-count": null, + /** @type {HTMLInputElement} */ "tag-generator-threshold": null, }, }; @@ -3452,8 +3587,8 @@ class SettingsView { $: (el) => (settings["model-add-drag-strict-on-field"] = el), textContent: "Strict dragging model onto a node's model field to add", }), - $el("div", [ - $el("p", ["Add model offset"]), + $el("label", [ + "Add model offset", $el("input", { $: (el) => (settings["model-add-offset"] = el), type: "number", @@ -3498,6 +3633,32 @@ class SettingsView { }), ]), */ + $el("h2", ["Random Tag Generator"]), + $select({ + $: (el) => (settings["tag-generator-sampler-method"] = el), + textContent: "Default sampling method", + options: ["Frequency", "Uniform"], + }), + $el("label", [ + "Default count", + $el("input", { + $: (el) => (settings["tag-generator-count"] = el), + type: "number", + name: "tag generator count", + step: 1, + min: 1, + }), + ]), + $el("label", [ + "Default minimum threshold", + $el("input", { + $: (el) => (settings["tag-generator-threshold"] = el), + type: "number", + name: "tag generator threshold", + step: 1, + min: 1, + }), + ]), ]); } } @@ -3711,6 +3872,7 @@ class ModelManager extends ComfyDialog { this.#modelInfo = new ModelInfo( this.#modelData, this.#refreshModels, + this.#settingsView.elements.settings, ); this.#browseView = new BrowseView(