FLUX.2 launch

This commit is contained in:
timudk
2025-11-25 07:25:25 -08:00
commit e80b84ed9f
24 changed files with 3238 additions and 0 deletions

20
.github/workflows/ci.yaml vendored Normal file
View File

@@ -0,0 +1,20 @@
name: CI
on: push
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.6.8
- name: Run Ruff
run: ruff check --output-format=github .
- name: Check imports
run: ruff check --select I --output-format=github .
- name: Check formatting
run: ruff format --check .

232
.gitignore vendored Normal file
View File

@@ -0,0 +1,232 @@
# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
# Local History for Visual Studio Code
.history/
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
output/

22
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,22 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.8
hooks:
# Run the linter.
- id: ruff
types_or: [python, pyi]
args: [--fix]
# Run isort
- id: ruff
types_or: [python, pyi]
name: ruff-isort
args: [--select, "I", --fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi]
- repo: https://github.com/google/yamlfmt
rev: v0.10.0
hooks:
- id: yamlfmt
args: [--conf, .yamlfmt.yaml]

3
.yamlfmt.yaml Normal file
View File

@@ -0,0 +1,3 @@
formatter:
type: basic
retain_line_breaks: True

142
README.md Normal file
View File

@@ -0,0 +1,142 @@
# FLUX.2
by Black Forest Labs: https://bfl.ai.
Documentation for our API can be found here: [docs.bfl.ai](https://docs.bfl.ai/).
This repo contains minimal inference code to run image generation & editing with our FLUX.2 open-weight models.
## `FLUX.2 [dev]`
`FLUX.2 [dev]` is a 32B parameter flow matching transformer model capable of generating and editing (multiple) images. The model is released under the [FLUX.2-dev Non-Commercial License](model_licenses/LICENSE-FLUX-DEV) and can be found [here](https://huggingface.co/black-forest-labs/FLUX.2-dev).
Note that the below script for `FLUX.2 [dev]` needs considerable amount of VRAM (H100-equivalent GPU). We partnered with Hugging Face to make quantized versions that run on consumer hardware; below you can find instructions on how to run it on a RTX 4090 with a remote text encoder, for other quantization sizes and combinations, check the [diffusers quantization guide here](docs/flux2_dev_hf.md).
### Text-to-image examples
![t2i-grid](assets/teaser_generation.png)
### Editing examples
![edit-grid](assets/teaser_editing.png)
### Prompt upsampling
`FLUX.2 [dev]` benefits significantly from prompt upsampling. The inference script below offers the option to use both local prompt upsampling with the same model we use for text encoding ([`Mistral-Small-3.2-24B-Instruct-2506`](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)), or alternatively, use any model on [OpenRouter](https://openrouter.ai/) via an API call.
See the [upsampling guide](docs/flux2_with_prompt_upsampling.md) for additional details and guidance on when to use upsampling.
## `FLUX.2` autoencoder
The FLUX.2 autoencoder has considerably improved over the [FLUX.1 autoencoder](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors). The autoencoder is released under [Apache 2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found [here](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/ae.safetensors). For more information, see our [technical blogpost](https://bfl.ai/blog/flux-2).
## Local installation
The inference code was tested on GB200 and H100 (with CPU offloading).
### GB200
On GB200, we tested `FLUX.2 [dev]` using CUDA 12.9 and Python 3.12.
```bash
python3.12 -m venv .venv
source .venv/bin/activate
pip install -e . --extra-index-url https://download.pytorch.org/whl/cu129 --no-cache-dir
```
### H100
On H100, we tested `FLUX.2 [dev]` using CUDA 12.6 and Python 3.10.
```bash
python3.10 -m venv .venv
source .venv/bin/activate
pip install -e . --extra-index-url https://download.pytorch.org/whl/cu126 --no-cache-dir
```
## Run the CLI
Before running the CLI, you may download the weights from [here](https://huggingface.co/black-forest-labs/FLUX.2-dev) and set the following environment variables.
```bash
export FLUX2_MODEL_PATH="<flux2_path>"
export AE_MODEL_PATH="<ae_path>"
```
If you don't set the environment variables, the weights will be downloaded
automatically.
You can start an interactive session with loaded weights by running the
following command. That will allow you to do both text to image generation as
well as editing one or multiple images.
```bash
export PYTHONPATH=src
python scripts/cli.py
```
On H100, we additionally set the flag `--cpu_offloading True`.
## Watermarking
We've added an option to embed invisible watermarks directly into the generated images
via the [invisible watermark library](https://github.com/ShieldMnt/invisible-watermark).
Additionally, we are recommending implementing a solution to mark the metadata of your outputs, such as [C2PA](https://c2pa.org/)
## 🧨 Lower VRAM diffusers example
The below example should run on a RTX 4090.
```python
import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from diffusers.utils import load_image
from huggingface_hub import get_token
import requests
import io
repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
device = "cuda:0"
torch_dtype = torch.bfloat16
def remote_text_encoder(prompts):
response = requests.post(
"https://remote-text-encoder-flux-2.huggingface.co/predict",
json={"prompt": prompts},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json"
}
)
prompt_embeds = torch.load(io.BytesIO(response.content))
return prompt_embeds.to(device)
pipe = Flux2Pipeline.from_pretrained(
repo_id, transformer=transformer, text_encoder=None, torch_dtype=torch_dtype
).to(device)
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
prompt_embeds=remote_text_encoder(prompt),
#image=load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png") #optional image input
generator=torch.Generator(device=device).manual_seed(42),
num_inference_steps=50, #28 steps can be a good trade-off
guidance_scale=4,
).images[0]
image.save("flux2_output.png")
```
## Citation
If you find the provided code or models useful for your research, consider citing them as:
```bib
@misc{flux-2-2025,
author={Black Forest Labs},
title={{FLUX.2: State-of-the-Art Visual Intelligence}},
year={2025},
howpublished={\url{https://bfl.ai/blog/flux-2}},
}
```

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 290 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

BIN
assets/teaser_editing.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 MiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 MiB

191
docs/flux2_dev_hf.md Normal file
View File

@@ -0,0 +1,191 @@
# 🧨 Running the model with diffusers
## Getting started
Install diffusers from `main`
```sh
pip install git+https://github.com/huggingface/diffusers.git
```
After accepting the gating on this repository, login with Hugging Face on your terminal
```sh
hf auth login
```
See below for inference instructions on different GPUs.
---
## 💾 Lower VRAM (~24-32G) - RTX 4090 and 5090
Those with 24-32GB of VRAM can use the model with **4-bit quantization**
### 4-bit transformer and remote text-encoder (~18G of VRAM)
The diffusers team is introducing a remote text-encoder for this release.
The text-embeddings are calculated in bf16 in the cloud and you only load the transformer into VRAM (this setting can get as low as ~18G of VRAM)
```py
import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from diffusers.utils import load_image
from huggingface_hub import get_token
import requests
import io
repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
device = "cuda:0"
torch_dtype = torch.bfloat16
def remote_text_encoder(prompts):
response = requests.post(
"https://remote-text-encoder-flux-2.huggingface.co/predict",
json={"prompt": prompts},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json"
}
)
prompt_embeds = torch.load(io.BytesIO(response.content))
return prompt_embeds.to(device)
pipe = Flux2Pipeline.from_pretrained(
repo_id, transformer=transformer, text_encoder=None, torch_dtype=torch_dtype
).to(device)
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
prompt_embeds=remote_text_encoder(prompt),
#image=load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png") #optional image input
generator=torch.Generator(device=device).manual_seed(42),
num_inference_steps=50, #28 steps can be a good trade-off
guidance_scale=4,
).images[0]
image.save("flux2_output.png")
```
### 4-bit transformer and 4-bit text-encoder (~20G of VRAM)
Load both the text-encoder and the transformer in 4-bit.
The text-encoder is offloaded from VRAM for the transformer to run with `pipe.enable_model_cpu_offload()`, making sure both will fit.
```py
import torch
from transformers import Mistral3ForConditionalGeneration
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
device = "cuda:0"
torch_dtype = torch.bfloat16
transformer = Flux2Transformer2DModel.from_pretrained(
repo_id, subfolder="transformer", torch_dtype=torch_dtype, device_map="cpu"
)
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
repo_id, subfolder="text_encoder", dtype=torch_dtype, device_map="cpu"
)
pipe = Flux2Pipeline.from_pretrained(
repo_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch_dtype
)
pipe.enable_model_cpu_offload()
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL + Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
prompt=prompt,
#image=[load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png")] #multi-image input
generator=torch.Generator(device=device).manual_seed(42),
num_inference_steps=50,
guidance_scale=4,
).images[0]
image.save("flux2_output.png")
```
To understand how different quantizations affect the model's abilities and quality, access the [FLUX.2 on diffusers](https://huggingface.co/blog/flux2) blog
---
## 💿 More VRAM (80G+)
Even an H100 can't hold the text-encoder, transormer and VAE at the same time. However, here it is a matter of activating the `pipe.enable_model_cpu_offload()`
And for H200, B200 or larger carts, everything fits.
```py
import torch
from diffusers import Flux2Pipeline
repo_id = "black-forest-labs/FLUX.2-dev"
device = "cuda:0"
torch_dtype = torch.bfloat16
pipe = Flux2Pipeline.from_pretrained(
repo_id, torch_dtype=torch_dtype
)
pipe.enable_model_cpu_offload() #deactivate for >80G VRAM carts like H200, B200, etc. and do a `pipe.to(device)` instead
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
prompt=prompt,
#image=[load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png")] #multi-image input
generator=torch.Generator(device=device).manual_seed(42),
num_inference_steps=50,
guidance_scale=4,
).images[0]
image.save("flux2_output.png")
```
### Remote text-encoder + H100
`pipe.enable_model_cpu_offload()` slows you down a bit. You can move as fast as possible on the H100 with the remote text-encoder
```py
import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from huggingface_hub import get_token
import requests
import io
repo_id = "black-forest-labs/FLUX.2-dev"
device = "cuda:0"
torch_dtype = torch.bfloat16
def remote_text_encoder(prompts):
response = requests.post(
"https://remote-text-encoder-flux-2.huggingface.co/predict",
json={"prompt": prompts},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json"
}
)
assert response.status_code == 200, f"{response.status_code=}"
prompt_embeds = torch.load(io.BytesIO(response.content))
return prompt_embeds.to(device)
pipe = Flux2Pipeline.from_pretrained(
repo_id, text_encoder=None, torch_dtype=torch_dtype
).to(device)
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL + Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
prompt_embeds=remote_text_encoder(prompt),
#image=[load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png")] #optional multi-image input
generator=torch.Generator(device=device).manual_seed(42),
num_inference_steps=50,
guidance_scale=4,
).images[0]
image.save("flux2_output.png")
```
## 🧮 Other VRAM sizes
If you have different GPU sizes, you can experiment with different quantizations, for example, for 40-48G VRAM GPUs, (8-bit) quantization instead of 4-bit can be a good trade-off. You can learn more on the [diffusers FLUX.2 release blog](https://huggingface.co/blog/flux2)

View File

@@ -0,0 +1,73 @@
# Prompt upsampling with FLUX.2
Prompt upsampling uses a large vision language model to expand and enrich your prompts before generation, which can significantly improve results for reasoning-heavy and complex generation tasks.
## When to use prompt upsampling
Prompt upsampling is particularly effective for prompts requiring reasoning or complex interpretation:
- **Text generation in images**: Creating memes, posters, or images where the model needs to generate creative or contextually appropriate text
- **Image-based instructions**: Prompts where the input image contains overlaid text, arrows, or annotations that need to be interpreted (e.g., "follow the instructions in the image", "read the diagram and generate the result")
- **Code and math reasoning**: Generating visualizations of algorithms, mathematical concepts, or code flow diagrams where logical structure is important
For simple, direct prompts (e.g., "a red car"), prompt upsampling may not provide significant benefits.
## Methods
We provide two methods for prompt upsampling:
### 1. API-based prompt upsampling (recommended)
API-based prompt upsampling via [OpenRouter](https://openrouter.ai/) generally produces better results by leveraging more capable models.
Set your API key as an environment variable:
```bash
export OPENROUTER_API_KEY="<api_key>"
```
Then run the CLI with upsampling enabled:
```bash
export PYTHONPATH=src
python scripts/cli.py --upsample_prompt_mode=openrouter
```
You can switch between different models using `--openrouter_model=<model_name>`.
Alternatively, you can just start the CLI via
```bash
export PYTHONPATH=src
python scripts/cli.py
```
and choose your prompt upsampling model interactively.
**Example output:**
| Prompt: "Make a meme about generating memes with this model" |
|:---:|
| <img src="../assets/t2i_upsample_example.png" alt="Output" width="512"> |
### 2. Local prompt upsampling
Local prompt upsampling uses [`Mistral-Small-3.2-24B-Instruct-2506`](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506), which is the model we use for text encoding in `FLUX.2 [dev]`. This option requires no API keys but may produce less detailed expansions.
To enable local prompt upsampling, use `--upsample_prompt_mode=local`.
**Example:**
<table>
<tr>
<th colspan="2" style="text-align: center;">Prompt: "Describe what the red arrow is seeing"</th>
</tr>
<tr>
<th>Input</th>
<th>Output</th>
</tr>
<tr>
<td align="center"><img src="../assets/i2i_upsample_input.png" alt="Input image"></td>
<td align="center"><img src="../assets/i2i_upsample_example.png" alt="Output image"></td>
</tr>
</table>

85
model_cards/FLUX.2-dev.md Normal file
View File

@@ -0,0 +1,85 @@
![Teaser](../assets/teaser_generation.png)
![Teaser](../assets/teaser_editing.png)
`FLUX.2 [dev]` is a 32 billion parameter rectified flow transformer capable of generating, editing and combining images based on text instructions.
For more information, please read our [blog post](https://bfl.ai/blog/flux-2).
# Key Features
1. State of the art in open text-to-image generation, single-reference editing and multi-reference editing.
2. No need for finetuning: character, object and style reference without additional training in one model.
4. Trained using guidance distillation, making `FLUX.2 [dev]` more efficient.
5. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
6. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [FLUX \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).
# Usage
We provide a reference implementation of `FLUX.2 [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux2).
Developers and creatives looking to build on top of `FLUX.2 [dev]` are encouraged to use this as a starting point.
`FLUX.2 [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers).
### Using with diffusers 🧨
For local deployment on a consumer type graphics card, like an RTX 4090 or an RTX 5090, please see the [diffusers docs](https://github.com/black-forest-labs/flux2/blob/main/docs/flux2_dev_hf.md) on our GitHub page.
As an example, here's a way to load a 4-bit quantized model with a remote text-encoder on an RTX 4090:
```python
import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from diffusers.utils import load_image
from huggingface_hub import get_token
import requests
import io
repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
device = "cuda:0"
torch_dtype = torch.bfloat16
def remote_text_encoder(prompts):
response = requests.post(
"https://remote-text-encoder-flux-2.huggingface.co/predict",
json={"prompt": prompts},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json"
}
)
prompt_embeds = torch.load(io.BytesIO(response.content))
return prompt_embeds.to(device)
pipe = Flux2Pipeline.from_pretrained(
repo_id, transformer=transformer, text_encoder=None, torch_dtype=torch_dtype
).to(device)
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
prompt_embeds=remote_text_encoder(prompt),
#image=load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png") #optional image input
generator=torch.Generator(device=device).manual_seed(42),
num_inference_steps=50, #28 steps can be a good trade-off
guidance_scale=4,
).images[0]
image.save("flux2_output.png")
```
---
# Risks
Black Forest Labs is committed to the responsible development and deployment of our models. Prior to releasing the FLUX.2 family of models, we evaluated and mitigated a number of risks in our model checkpoints and hosted services, including the generation of unlawful content such as child sexual abuse material (CSAM) and nonconsensual intimate imagery (NCII). We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks:
1. Pre-training mitigation. We filtered pre-training data for multiple categories of “not safe for work” (NSFW) and known child sexual abuse material (CSAM) to help prevent a user generating unlawful content in response to text prompts or uploaded images. We have partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known CSAM from the training data.
2. Post-training mitigation. Subsequently, we undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse, including both text-to-image (T2I) and image-to-image (I2I) attacks. By inhibiting certain behaviors and suppressing certain concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or NCII from a text prompt, or transforming an uploaded image into synthetic CSAM or NCII.
3. Ongoing evaluation. Throughout this process, we conducted multiple internal and external third-party evaluations of model checkpoints to identify further opportunities for mitigation. External third-party evaluations focused on eliciting CSAM and NCII through adversarial testing with (i) text-only prompts, (ii) a single uploaded reference image with text prompts, and (iii) multiple uploaded reference images with text prompts. Based on this feedback, we conducted further safety fine-tuning to produce our open-weight model (FLUX.2 [dev]).
4. Release decision. After safety fine-tuning and prior to release, we conducted a final third-party evaluation of the proposed release checkpoint, focused on T2I and I2I generation of synthetic CSAM and NCII, including a comparison with other open-weight T2I and I2I models (total prompts n≈2,800). The final FLUX.2 [dev] checkpoint demonstrated high resilience against violative inputs in complex generation and editing tasks, and demonstrated higher resilience than leading open-weight models across these risk categories. Based on these findings, we approved the release of the FLUX.2 Pro model via API and the release of the open-weight FLUX.2 [dev] model under a non-commercial license to support third-party research and development.
5. Inference filters. The repository for the FLUX.2 [dev] model includes filters for NSFW and IP-infringing content at input and output. Filters or manual review must be used with the model under the terms of the FLUX.2 [dev] Non-Commercial License. We may approach known deployers of the FLUX.2 [dev] model at random to verify that filters or manual review processes are in place. Additionally, we apply multiple filters to intercept text prompts, uploaded images, and output images on the API for FLUX.2 [pro]. We utilize both in-house and third-party supplied filters to prevent CSAM and NCII outputs, including filters provided by Hive and Microsoft. We provide filters for other categories of potentially harmful content, including gore, which can be adjusted by developers based on their specific risk profile and legitimate use cases.
6. Content provenance. Content provenance features can help users and platforms better identify, label, and interpret AI-generated content online. The inference code for FLUX.2 [dev] implements an example of pixel-layer watermarking, and this repository includes links to the Coalition for Content Provenance and Authenticity (C2PA) standard for metadata. The API for FLUX.2 Pro applies cryptographically-signed C2PA metadata to output content to indicate that images were produced with our model.
7. Policies. Use of our models and access to our API are governed by our FLUX [dev] Non-Commercial License (for our non-commercial open-weight users); Developer Terms of Service, Self-Hosted Commercial License Terms, and Usage Policy (for our commercial open-weight model users); and Developer Terms of Service, FLUX API Service Terms, and Usage Policy (for our API users). These prohibit the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX.2 [dev] model on Hugging Face.
8. Monitoring. We are monitoring for patterns of violative use after release. We continue to issue and escalate takedown requests to websites, services, or businesses that misuse our models. Additionally, we may ban users or developers who we detect intentionally and repeatedly violate our policies via the FLUX API. Additionally, we provide a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. We maintain a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and welcome ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations.
# License
This model falls under the [FLUX \[dev\] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).

View File

@@ -0,0 +1,56 @@
FLUX [dev] Non-Commercial License v2.0
Black Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make the weights, parameters, and inference code for the FLUX [dev] Models (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX [dev] Non-Commercial License (“License”). “Models” includes the models denoted as FLUX.x [dev], where “.x” denotes the FLUX model version number, including but not limited to FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA, FLUX.1 Depth [dev] LoRA, FLUX.1 Kontext [dev], FLUX.1 Krea [dev], and FLUX.2 [dev], and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals, and instructions for the use and operation thereof (individually and collectively, the “FLUX [dev] Models”). Note that we may also make available certain elements of what is included in the definition of “FLUX [dev] Model” under a separate license, such as the inference code, and nothing in this License will be deemed to restrict or limit any other licenses granted by us in such elements.
By downloading, accessing, using, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX [dev] Model and you must immediately cease using the FLUX [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX [dev] Model on behalf of your employer or other entity.
1. Definitions.
a. “Derivative” means any (i) modified version of the FLUX [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX [dev] Models and/or the Derivatives as the case may be.
c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the FLUX [dev] Model, Derivatives, or Content Filters (as defined below): (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment; and (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use (a) for revenue-generating activity, (b) in direct interactions with or that has impact on end users, or (c) to train, fine tune, or distill other models for commercial use, in each case, is not a Non-Commercial Purpose.
d. “Outputs” means any content generated by the operation of the FLUX [dev] Models or Derivatives from an input (such as an image input) or prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of the FLUX [dev] Models, such as any fine-tuned versions of the FLUX [dev] Models, the weights, or parameters.
e. “you” or “your” means the individual or entity entering into this License with Company.
2. License Grant.
a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free, and limited license to access, use, create Derivatives of, and Distribute the FLUX [dev] Models and Derivatives solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Companys prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein regarding the FLUX [dev] Model also apply to any Derivative you create or that are created on your behalf.
b. Non-Commercial Use Only. You may only access, use, Distribute, or create Derivatives of the FLUX [dev] Model or Derivatives for Non-Commercial Purposes. If you want to use a FLUX [dev] Model or a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Companys sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please see www.bfl.ai if you would like a commercial license.
c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity, or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune, or distill a model that is competitive with a FLUX [dev] Model.
e. You may access, use, Distribute, or create Output of the FLUX [dev] Model or Derivatives if you: (i) (A) implement and maintain content filtering measures (“Content Filters”) for your use of the FLUX [dev] Model or Derivatives to prevent the creation, display, transmission, generation, or dissemination of unlawful or infringing content, which may include Content Filters that we may make available for use with the FLUX [dev] Model (“Provided Content Filters”), or (B) ensure Output undergoes review for unlawful or infringing content before public or non-public distribution, display, transmission or dissemination; and (ii) ensure Output includes disclosure (or other indication) that the Output was generated or modified using artificial intelligence technologies to the extent required under applicable law.
3. Distribution. Subject to this License, you may Distribute copies of the FLUX [dev] Model and/or Derivatives made by you, under the following conditions:
a. you must make available a copy of this License to third-party recipients of the FLUX [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
b. you must prominently display the following notice alongside the Distribution of the FLUX [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX [dev] Model or Derivative) (the “Attribution Notice”):
“The FLUX [dev] Model is licensed by Black Forest Labs Inc. under the FLUX [dev] Non-Commercial License. Copyright Black Forest Labs Inc.
IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
c. in the case of Distribution of Derivatives made by you: (i) you must also include in the Attribution Notice a statement that you have modified the applicable FLUX [dev] Model; (ii) any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients use of the FLUX [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions and must include disclaimer of warranties and limitation of liability provisions that are at least as protective of Company as those set forth herein; and (iii) you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
4. Restrictions. You will not, and will not permit, assist or cause any third party to
a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX [dev] Model (or any Derivative thereof, or any data produced by the FLUX [dev] Model), in whole or in part, (i) for any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates (or is likely to infringe, misappropriate, or otherwise violate) any third partys legal rights, including rights of publicity or “digital replica” rights, (vi) in any unlawful, fraudulent, defamatory, or abusive activity, (vii) to generate unlawful content, including child sexual abuse material, or non-consensual intimate images; or (viii) in any manner that violates any applicable law and any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, any and all laws governing the processing of biometric information, and the EU Artificial Intelligence Act (Regulation (EU) 2024/1689), as well as all amendments and successor laws to any of the foregoing);
b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX [dev] Model;
c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX [dev] Model;
d. offer or impose any terms on the FLUX [dev] Model that alter, restrict, or are inconsistent with the terms of this License;
e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX [dev] Model;
f. directly or indirectly Distribute, export, or otherwise transfer FLUX [dev] Model (i) to any individual, entity, or country prohibited by Export Laws; (ii) to anyone on U.S. or non-U.S. government restricted parties lists; (iii) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; (iv) use or download FLUX [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (v) will not disguise your location through IP proxying or other methods.
5. DISCLAIMERS. THE FLUX [DEV] MODEL AND PROVIDED CONTENT FILTERS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX [DEV] MODEL AND PROVIDED CONTENT FILTERS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX [DEV] MODEL AND PROVIDED CONTENT FILTERS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX [DEV] MODEL, ITS CONSTITUENT COMPONENTS, PROVIDED CONTENT FILTERS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUALS PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
7. INDEMNIFICATION. You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX [dev] Model (including in connection with any Output, results or data generated from such access or use, or from your access or use of any Content Filters), including any High-Risk Use; (b) your Content Filters, including your failure to implement any Content Filters where required by this License such as in Section 2(e); (c) your violation of this License; or (d) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Companys sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
8. Termination; Survival.
a. This License will automatically terminate upon any breach by you of the terms of this License.
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
c. If you initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX [dev] Model, any Derivative, or Provided Content Filters, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX [dev] Model, any Derivatives, and any Provided Content Filters. The following sections survive termination of this License: 2(c), 2(d), 4-11.
9. Third Party Materials. The FLUX [dev] Model and Provided Content Filters may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name, logo or trademark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX [dev] Model and its creators.
11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter.

44
pyproject.toml Normal file
View File

@@ -0,0 +1,44 @@
[project]
name = "flux"
version = "0.1.0"
description = "Inference codebase for FLUX.2"
readme = "README.md"
requires-python = ">=3.10,<3.13"
authors = [
{ name = "Black Forest Labs", email = "support@blackforestlabs.ai" }
]
license = { file = "LICENSE.md" }
dependencies = [
"torch==2.8.0",
"torchvision==0.23.0",
"einops==0.8.1",
"transformers==4.56.1",
"safetensors==0.4.5",
"fire==0.7.1",
"openai==2.8.1",
]
[project.optional-dependencies]
dev = [
"ruff==0.6.8",
]
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"
[tool.ruff]
line-length = 110
target-version = "py312"
[tool.ruff.lint]
ignore = [
"E501", # line too long (handled by formatter)
]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
line-ending = "auto"
docstring-code-format = true

525
scripts/cli.py Normal file
View File

@@ -0,0 +1,525 @@
import json
import os
import random
import shlex
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
import torch
from einops import rearrange
from PIL import ExifTags, Image
from flux2.openrouter_api_client import DEFAULT_SAMPLING_PARAMS, OpenRouterAPIClient
from flux2.sampling import (
batched_prc_img,
batched_prc_txt,
denoise,
encode_image_refs,
get_schedule,
scatter_ids,
)
from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_mistral_small_embedder
# from flux2.watermark import embed_watermark
@dataclass
class Config:
prompt: str = "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture"
seed: Optional[int] = None
width: int = 1360
height: int = 768
num_steps: int = 50
guidance: float = 4.0
input_images: List[Path] = field(default_factory=list)
match_image_size: Optional[int] = None # Index of input_images to match size from
upsample_prompt_mode: Literal["none", "local", "openrouter"] = "none"
openrouter_model: str = "mistralai/pixtral-large-2411" # OpenRouter model name
def copy(self) -> "Config":
return Config(
prompt=self.prompt,
seed=self.seed,
width=self.width,
height=self.height,
num_steps=self.num_steps,
guidance=self.guidance,
input_images=list(self.input_images),
match_image_size=self.match_image_size,
upsample_prompt_mode=self.upsample_prompt_mode,
openrouter_model=self.openrouter_model,
)
DEFAULTS = Config()
INT_FIELDS = {"width", "height", "seed", "num_steps", "match_image_size"}
FLOAT_FIELDS = {"guidance"}
LIST_FIELDS = {"input_images"}
UPSAMPLING_MODE_FIELDS = ("none", "local", "openrouter")
STR_FIELDS = {"openrouter_model"}
def coerce_value(key: str, raw: str):
"""Convert a raw string to the correct field type."""
if key in INT_FIELDS:
if raw.lower() == "none" or raw == "":
return None
return int(raw)
if key in FLOAT_FIELDS:
return float(raw)
if key in STR_FIELDS:
return raw.strip().strip('"').strip("'")
if key in LIST_FIELDS:
# Handle empty list cases
if raw == "" or raw == "[]":
return []
# Accept comma-separated or space-separated; strip quotes.
items = []
# If user passed a single token that contains commas, split on commas.
tokens = [raw] if ("," in raw and " " not in raw) else shlex.split(raw)
for tok in tokens:
for part in tok.split(","):
part = part.strip()
if part:
if os.path.exists(part):
items.append(Path(part))
else:
print(f"File {part} not found. Skipping for now. Please check your path")
return items
if key == "upsample_prompt_mode":
v = str(raw).strip().strip('"').strip("'").lower()
if v in UPSAMPLING_MODE_FIELDS:
return v
raise ValueError(
f"invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(UPSAMPLING_MODE_FIELDS)}"
)
# plain strings
return raw
def apply_updates(cfg: Config, updates: Dict[str, Any]) -> None:
for k, v in updates.items():
if not hasattr(cfg, k):
print(f" ! unknown key: {k}", file=sys.stderr)
continue
# Validate upsample_prompt_mode
if k == "upsample_prompt_mode":
valid_modes = {"none", "local", "openrouter"}
if v not in valid_modes:
print(
f" ! Invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(valid_modes)}",
file=sys.stderr,
)
continue
setattr(cfg, k, v)
def parse_key_values(line: str) -> Dict[str, Any]:
"""
Parse shell-like 'key=value' pairs. Values can be quoted.
Example: prompt="a dog" width=768 input_images="in1.png,in2.jpg"
"""
updates: Dict[str, Any] = {}
for token in shlex.split(line):
if "=" not in token:
# Allow bare commands like: run, show, reset, quit
updates[token] = True
continue
key, val = token.split("=", 1)
key = key.strip()
val = val.strip()
try:
updates[key] = coerce_value(key, val)
except Exception as e:
print(f" ! could not parse {key}={val!r}: {e}", file=sys.stderr)
return updates
def print_config(cfg: Config):
d = asdict(cfg)
d["input_images"] = [str(p) for p in cfg.input_images]
print("Current config:")
for k in [
"prompt",
"seed",
"width",
"height",
"num_steps",
"guidance",
"input_images",
"match_image_size",
"upsample_prompt_mode",
"openrouter_model",
]:
print(f" {k}: {d[k]}")
print()
def print_help():
print("""
Available commands:
[Enter] - Run generation with current config
run - Run generation with current config
show - Show current configuration
reset - Reset configuration to defaults
help, h, ? - Show this help message
quit, q, exit - Exit the program
Setting parameters:
key=value - Update a config parameter (shows updated config, doesn't run)
Examples:
prompt="a cat in a hat"
width=768 height=768
seed=42
num_steps=30
guidance=3.5
input_images="img1.jpg,img2.jpg"
match_image_size=0 (use dimensions from first input image)
upsample_prompt_mode="none" (prompt upsampling mode: "none", "local", or "openrouter")
openrouter_model="mistralai/pixtral-large-2411" (OpenRouter model name)
You can combine parameter updates:
prompt="sunset" width=1920 height=1080
Parameters:
prompt - Text prompt for generation (string)
seed - Random seed (integer or 'none' for random)
width - Output width in pixels (integer)
height - Output height in pixels (integer)
num_steps - Number of denoising steps (integer)
guidance - Guidance scale (float)
input_images - Comma-separated list of input image paths (list)
match_image_size - Index of input image to match dimensions from (integer, 0-based)
upsample_prompt_mode - Prompt upsampling mode: "none" (default), "local", or "openrouter" (string)
openrouter_model - OpenRouter model name (string, default: "mistralai/pixtral-large-2411")
Examples: "mistralai/pixtral-large-2411", "qwen/qwen3-vl-235b-a22b-instruct", etc.
Note: For "openrouter" mode, set OPENROUTER_API_KEY environment variable
""")
# ---------- Main Loop ----------
def main(
model_name: str = "flux.2-dev",
single_eval: bool = False,
prompt: str | None = None,
debug_mode: bool = False,
cpu_offloading: bool = False,
**overwrite,
):
assert (
model_name.lower() in FLUX2_MODEL_INFO
), f"{model_name} is not available, choose from {FLUX2_MODEL_INFO.keys()}"
torch_device = torch.device("cuda")
mistral = load_mistral_small_embedder()
model = load_flow_model(
model_name, debug_mode=debug_mode, device="cpu" if cpu_offloading else torch_device
)
ae = load_ae(model_name)
ae.eval()
mistral.eval()
# API client will be initialized lazily when needed
openrouter_api_client: Optional[OpenRouterAPIClient] = None
cfg = DEFAULTS.copy()
changes = [f"{key}={value}" for key, value in overwrite.items()]
updates = parse_key_values(" ".join(changes))
apply_updates(cfg, updates)
if prompt is not None:
cfg.prompt = prompt
print_config(cfg)
while True:
if not single_eval:
try:
line = input("> ").strip()
except (EOFError, KeyboardInterrupt):
print("\nbye!")
break
if not line:
# Empty -> run with current config
cmd = "run"
updates = {}
else:
try:
updates = parse_key_values(line)
except Exception as e: # noqa: BLE001
print(f" ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
print(
" ! Please check your syntax (e.g., matching quotes) and try again.\n",
file=sys.stderr,
)
continue
if "prompt" in updates and mistral.test_txt(updates["prompt"]):
print(
"Your prompt has been flagged for potential copyright or public personas concerns. Please choose another."
)
updates.pop("prompt")
if "input_images" in updates:
flagged = False
for image in updates["input_images"]:
if mistral.test_image(image):
print(f"The image {image} has been flagged as unsuitable. Please choose another.")
flagged = True
if flagged:
updates.pop("input_images")
# If the line was only 'run' / 'show' / ... it will appear as {cmd: True}
# If it had key=val pairs, there may be no bare command -> just update config
bare_cmds = [k for k, v in updates.items() if v is True and k.isalpha()]
cmd = bare_cmds[0] if bare_cmds else None
# Remove bare commands from updates so they don't get applied as fields
for c in bare_cmds:
updates.pop(c, None)
if cmd in ("quit", "q", "exit"):
print("bye!")
break
elif cmd == "reset":
cfg = DEFAULTS.copy()
print_config(cfg)
continue
elif cmd == "show":
print_config(cfg)
continue
elif cmd in ("help", "h", "?"):
print_help()
continue
# Apply key=value changes
if updates:
apply_updates(cfg, updates)
print_config(cfg)
continue
# Only run if explicitly requested (empty line or 'run' command)
if cmd != "run":
if cmd is not None:
print(f" ! Unknown command: '{cmd}'", file=sys.stderr)
print(" ! Type 'help' to see available commands.\n", file=sys.stderr)
continue
try:
# Load input images first to potentially match dimensions
img_ctx = [Image.open(input_image) for input_image in cfg.input_images]
# Apply match_image_size if specified
width = cfg.width
height = cfg.height
if cfg.match_image_size is not None:
if cfg.match_image_size < 0 or cfg.match_image_size >= len(img_ctx):
print(
f" ! match_image_size={cfg.match_image_size} is out of range (0-{len(img_ctx)-1})",
file=sys.stderr,
)
print(f" ! Using default dimensions: {width}x{height}", file=sys.stderr)
else:
ref_img = img_ctx[cfg.match_image_size]
width, height = ref_img.size
print(f" Matched dimensions from image {cfg.match_image_size}: {width}x{height}")
seed = cfg.seed if cfg.seed is not None else random.randrange(2**31)
dir = Path("output")
dir.mkdir(exist_ok=True)
output_name = dir / f"sample_{len(list(dir.glob('*')))}.png"
with torch.no_grad():
ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)
if cfg.upsample_prompt_mode == "openrouter":
try:
# Ensure API key is available, otherwise prompt the user
api_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
if not api_key:
try:
entered = input(
"OPENROUTER_API_KEY not set. Enter it now (leave blank to skip OpenRouter upsampling): "
).strip()
except (EOFError, KeyboardInterrupt):
entered = ""
if entered:
os.environ["OPENROUTER_API_KEY"] = entered
else:
print(
" ! No API key provided; disabling OpenRouter upsampling",
file=sys.stderr,
)
cfg.upsample_prompt_mode = "none"
prompt = cfg.prompt
# Skip OpenRouter flow
# Only proceed if still in openrouter mode (not disabled above)
if cfg.upsample_prompt_mode == "openrouter":
# Let user specify sampling params, or use model defaults if available
sampling_params_input = ""
try:
sampling_params_input = input(
"Enter OpenRouter sampling params as JSON or key=value (blank to use defaults): "
).strip()
except (EOFError, KeyboardInterrupt):
sampling_params_input = ""
sampling_params: Dict[str, Any] = {}
if sampling_params_input:
# Try JSON first
parsed_ok = False
try:
parsed = json.loads(sampling_params_input)
if isinstance(parsed, dict):
sampling_params = parsed
parsed_ok = True
except Exception:
parsed_ok = False
if not parsed_ok:
# Fallback: parse key=value pairs separated by spaces or commas
tokens = [
tok
for tok in sampling_params_input.replace(",", " ").split(" ")
if tok
]
for tok in tokens:
if "=" not in tok:
continue
k, v = tok.split("=", 1)
v_str = v.strip()
v_low = v_str.lower()
if v_low in {"true", "false"}:
val: Any = v_low == "true"
else:
try:
if "." in v_str:
num = float(v_str)
val = int(num) if num.is_integer() else num
else:
val = int(v_str)
except Exception:
val = v_str
sampling_params[k.strip()] = val
print(f" Using custom OpenRouter sampling params: {sampling_params}")
else:
model_key = cfg.openrouter_model
default_params = DEFAULT_SAMPLING_PARAMS.get(model_key)
if default_params:
sampling_params = default_params
print(
f" Using default OpenRouter sampling params for {model_key}: {sampling_params}"
)
else:
print(
f" Setting no OpenRouter sampling params: not set for this model ({model_key})"
)
# Initialize or reinitialize client if model changed
if (
openrouter_api_client is None
or openrouter_api_client.model != cfg.openrouter_model
or getattr(openrouter_api_client, "sampling_params", None) != sampling_params
):
openrouter_api_client = OpenRouterAPIClient(
model=cfg.openrouter_model,
sampling_params=sampling_params,
)
else:
# Ensure client uses latest sampling params
openrouter_api_client.sampling_params = sampling_params
upsampled_prompts = openrouter_api_client.upsample_prompt(
[cfg.prompt], img=[img_ctx] if img_ctx else None
)
prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
except Exception as e:
print(f" ! Failed to upsample prompt via OpenRouter API: {e}", file=sys.stderr)
print(
" ! Disabling OpenRouter upsampling and falling back to original prompt",
file=sys.stderr,
)
cfg.upsample_prompt_mode = "none"
prompt = cfg.prompt
elif cfg.upsample_prompt_mode == "local":
# Use local model for upsampling
upsampled_prompts = mistral.upsample_prompt(
[cfg.prompt], img=[img_ctx] if img_ctx else None
)
prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
else:
# upsample_prompt_mode == "none" or invalid value
prompt = cfg.prompt
print("Generating with prompt: ", prompt)
ctx = mistral([prompt]).to(torch.bfloat16)
ctx, ctx_ids = batched_prc_txt(ctx)
if cpu_offloading:
mistral = mistral.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# Create noise
shape = (1, 128, height // 16, width // 16)
generator = torch.Generator(device="cuda").manual_seed(seed)
randn = torch.randn(shape, generator=generator, dtype=torch.bfloat16, device="cuda")
x, x_ids = batched_prc_img(randn)
timesteps = get_schedule(cfg.num_steps, x.shape[1])
x = denoise(
model,
x,
x_ids,
ctx,
ctx_ids,
timesteps=timesteps,
guidance=cfg.guidance,
img_cond_seq=ref_tokens,
img_cond_seq_ids=ref_ids,
)
x = torch.cat(scatter_ids(x, x_ids)).squeeze(2)
x = ae.decode(x).float()
# x = embed_watermark(x)
if cpu_offloading:
model = model.cpu()
torch.cuda.empty_cache()
mistral = mistral.to(torch_device)
x = x.clamp(-1, 1)
x = rearrange(x[0], "c h w -> h w c")
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
if mistral.test_image(img):
print("Your output has been flagged. Please choose another prompt / input image combination")
else:
exif_data = Image.Exif()
exif_data[ExifTags.Base.Software] = "AI generated;flux2"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
img.save(output_name, exif=exif_data, quality=95, subsampling=0)
print(f"Saved {output_name}")
except Exception as e: # noqa: BLE001
print(f"\n ERROR: {type(e).__name__}: {e}", file=sys.stderr)
print(" The model is still loaded. Please fix the error and try again.\n", file=sys.stderr)
if single_eval:
break
if __name__ == "__main__":
from fire import Fire
Fire(main)

336
src/flux2/autoencoder.py Normal file
View File

@@ -0,0 +1,336 @@
import math
from dataclasses import dataclass, field
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int = 256
in_channels: int = 3
ch: int = 128
out_ch: int = 3
ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
num_res_blocks: int = 2
z_channels: int = 32
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
h = self.quant_conv(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
z = self.post_quant_conv(z)
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(
math.prod(self.ps) * params.z_channels,
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
def normalize(self, z):
self.bn.eval()
return self.bn(z)
def inv_normalize(self, z):
self.bn.eval()
s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
m = self.bn.running_mean.view(1, -1, 1, 1)
return z * s + m
def encode(self, x: Tensor) -> Tensor:
moments = self.encoder(x)
mean = torch.chunk(moments, 2, dim=1)[0]
z = rearrange(
mean,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = self.normalize(z)
return z
def decode(self, z: Tensor) -> Tensor:
z = self.inv_normalize(z)
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
dec = self.decoder(z)
return dec

451
src/flux2/model.py Normal file
View File

@@ -0,0 +1,451 @@
import math
from dataclasses import dataclass, field
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class Flux2Params:
in_channels: int = 128
context_in_dim: int = 15360
hidden_size: int = 6144
num_heads: int = 48
depth: int = 8
depth_single_blocks: int = 48
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
class Flux2(nn.Module):
def __init__(self, params: Flux2Params):
super().__init__()
self.in_channels = params.in_channels
self.out_channels = params.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
)
for _ in range(params.depth_single_blocks)
]
)
self.double_stream_modulation_img = Modulation(
self.hidden_size,
double=True,
disable_bias=True,
)
self.double_stream_modulation_txt = Modulation(
self.hidden_size,
double=True,
disable_bias=True,
)
self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)
self.final_layer = LastLayer(
self.hidden_size,
self.out_channels,
)
def forward(
self,
x: Tensor,
x_ids: Tensor,
timesteps: Tensor,
ctx: Tensor,
ctx_ids: Tensor,
guidance: Tensor,
):
num_txt_tokens = ctx.shape[1]
timestep_emb = timestep_embedding(timesteps, 256)
vec = self.time_in(timestep_emb)
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
double_block_mod_img = self.double_stream_modulation_img(vec)
double_block_mod_txt = self.double_stream_modulation_txt(vec)
single_block_mod, _ = self.single_stream_modulation(vec)
img = self.img_in(x)
txt = self.txt_in(ctx)
pe_x = self.pe_embedder(x_ids)
pe_ctx = self.pe_embedder(ctx_ids)
for block in self.double_blocks:
img, txt = block(
img,
txt,
pe_x,
pe_ctx,
double_block_mod_img,
double_block_mod_txt,
)
img = torch.cat((txt, img), dim=1)
pe = torch.cat((pe_ctx, pe_x), dim=2)
for i, block in enumerate(self.single_blocks):
img = block(
img,
pe,
single_block_mod,
)
img = img[:, num_txt_tokens:, ...]
img = self.final_layer(img, vec)
return img
class SelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim, bias=False)
class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
self.gate_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return self.gate_fn(x1) * x2
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, disable_bias: bool = False):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
def forward(self, vec: torch.Tensor):
out = self.lin(nn.functional.silu(vec))
if out.ndim == 2:
out = out[:, None, :]
out = out.chunk(self.multiplier, dim=-1)
return out[:3], out[3:] if self.is_double else None
class LastLayer(nn.Module):
def __init__(
self,
hidden_size: int,
out_channels: int,
):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=False)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
mod = self.adaLN_modulation(vec)
shift, scale = mod.chunk(2, dim=-1)
if shift.ndim == 2:
shift = shift[:, None, :]
scale = scale[:, None, :]
x = (1 + scale) * self.norm_final(x) + shift
x = self.linear(x)
return x
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_mult_factor = 2
self.linear1 = nn.Linear(
hidden_size,
hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
bias=False,
)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = SiLUActivation()
def forward(
self,
x: Tensor,
pe: Tensor,
mod: tuple[Tensor, Tensor],
) -> Tensor:
mod_shift, mod_scale, mod_gate = mod
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
qkv, mlp = torch.split(
self.linear1(x_mod),
[3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
dim=-1,
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod_gate * output
class DoubleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"
self.hidden_size = hidden_size
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_mult_factor = 2
self.img_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
SiLUActivation(),
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(
hidden_size,
mlp_hidden_dim * self.mlp_mult_factor,
bias=False,
),
SiLUActivation(),
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)
def forward(
self,
img: Tensor,
txt: Tensor,
pe: Tensor,
pe_ctx: Tensor,
mod_img: tuple[Tensor, Tensor],
mod_txt: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = mod_img
txt_mod1, txt_mod2 = mod_txt
img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
pe = torch.cat((pe_ctx, pe), dim=2)
attn = attention(q, k, v, pe)
txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]
# calculate the img blocks
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
img = img + img_mod2_gate * self.img_mlp(
(1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
)
# calculate the txt blocks
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2_gate * self.txt_mlp(
(1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
)
return img, txt
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(len(self.axes_dim))],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -0,0 +1,129 @@
"""OpenRouter API client for prompt upsampling."""
import os
from typing import Any
from openai import OpenAI
from PIL import Image
from .system_messages import SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
from .util import image_to_base64
DEFAULT_SAMPLING_PARAMS = {"mistralai/pixtral-large-2411": dict(temperature=0.15)}
class OpenRouterAPIClient:
"""Client for OpenRouter API-based prompt upsampling."""
def __init__(
self,
sampling_params: dict[str, Any],
model: str = "mistralai/pixtral-large-2411",
max_tokens: int = 512,
):
"""
Initialize the OpenRouter API client.
Args:
model: Model name to use for upsampling. Defaults to "mistralai/pixtral-large-2411".
Can be any OpenRouter model (e.g., "mistralai/pixtral-large-2411",
"qwen/qwen3-vl-235b-a22b-instruct", etc.)
"""
self.api_key = os.environ["OPENROUTER_API_KEY"]
self.client = OpenAI(api_key=self.api_key, base_url="https://openrouter.ai/api/v1")
self.model = model
self.sampling_params = sampling_params
self.max_tokens = max_tokens
def _format_messages(
self,
prompt: str,
system_message: str,
images: list[Image.Image] | None = None,
) -> list[dict[str, str]]:
messages: list[dict[str, str]] = [
{"role": "system", "content": system_message},
]
if images:
content = []
for img in images:
img_base64 = image_to_base64(img)
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{img_base64}",
},
}
)
content.append({"type": "text", "text": prompt})
messages.append({"role": "user", "content": content})
else:
messages.append({"role": "user", "content": prompt})
return messages
def upsample_prompt(
self,
txt: list[str],
img: list[Image.Image] | list[list[Image.Image]] | None = None,
) -> list[str]:
"""
Upsample prompts using OpenRouter API.
Args:
txt: List of input prompts to upsample
img: Optional list of images or list of lists of images.
If None or empty, uses t2i mode, otherwise i2i mode.
Returns:
List of upsampled prompts
"""
# Determine system message based on whether images are provided
has_images = img is not None and len(img) > 0
if has_images and isinstance(img[0], list):
has_images = len(img[0]) > 0
if has_images:
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
else:
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
upsampled_prompts = []
# Process each prompt (potentially with images)
for i, prompt in enumerate(txt):
# Get images for this prompt
prompt_images: list[Image.Image] | None = None
if img is not None and len(img) > i:
if isinstance(img[i], list):
prompt_images = img[i] if len(img[i]) > 0 else None
elif isinstance(img[i], Image.Image):
prompt_images = [img[i]]
# Format messages
messages = self._format_messages(
prompt=prompt,
system_message=system_message,
images=prompt_images,
)
# Call API
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=self.max_tokens,
**self.sampling_params,
)
upsampled = response.choices[0].message.content.strip()
upsampled_prompts.append(upsampled)
except Exception as e:
print(f"Error upsampling prompt via OpenRouter API: {e}, returning original prompt")
upsampled_prompts.append(prompt)
return upsampled_prompts

339
src/flux2/sampling.py Normal file
View File

@@ -0,0 +1,339 @@
import math
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch import Tensor
from .model import Flux2
def compress_time(t_ids: Tensor) -> Tensor:
assert t_ids.ndim == 1
t_ids_max = torch.max(t_ids)
t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
t_remap[t_unique_sorted_ids] = torch.arange(
len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
)
t_ids_compressed = t_remap[t_ids]
return t_ids_compressed
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
"""
using position ids to scatter tokens into place
"""
x_list = []
t_coords = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
t_ids = pos[:, 0].to(torch.int64)
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
t_ids_cmpr = compress_time(t_ids)
t = torch.max(t_ids_cmpr) + 1
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
t_coords.append(torch.unique(t_ids, sorted=True))
return x_list
def encode_image_refs(ae, img_ctx: list[Image.Image]):
scale = 10
if len(img_ctx) > 1:
limit_pixels = 1024**2
elif len(img_ctx) == 1:
limit_pixels = 2024**2
else:
limit_pixels = None
if not img_ctx:
return None, None
img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
if not isinstance(img_ctx_prep, list):
img_ctx_prep = [img_ctx_prep]
# Encode each reference image
encoded_refs = []
for img in img_ctx_prep:
encoded = ae.encode(img[None].cuda())[0]
encoded_refs.append(encoded)
# Create time offsets for each reference
t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
t_off = [t.view(-1) for t in t_off]
# Process with position IDs
ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
# Concatenate all references along sequence dimension
ref_tokens = torch.cat(ref_tokens, dim=0) # (total_ref_tokens, C)
ref_ids = torch.cat(ref_ids, dim=0) # (total_ref_tokens, 4)
# Add batch dimension
ref_tokens = ref_tokens.unsqueeze(0) # (1, total_ref_tokens, C)
ref_ids = ref_ids.unsqueeze(0) # (1, total_ref_tokens, 4)
return ref_tokens.to(torch.bfloat16), ref_ids
def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
_l, _ = x.shape # noqa: F841
coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(1), # dummy dimension
"w": torch.arange(1), # dummy dimension
"l": torch.arange(_l),
}
x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
return x, x_ids.to(x.device)
def batched_wrapper(fn):
def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
)
)
x, x_ids = zip(*results)
return torch.stack(x), torch.stack(x_ids)
return batched_prc
def listed_wrapper(fn):
def listed_prc(
x: list[Tensor],
t_coord: list[Tensor] | None = None,
) -> tuple[list[Tensor], list[Tensor]]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
)
)
x, x_ids = zip(*results)
return list(x), list(x_ids)
return listed_prc
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
_, h, w = x.shape # noqa: F841
x_coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(h),
"w": torch.arange(w),
"l": torch.arange(1),
}
x_ids = torch.cartesian_prod(x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"])
x = rearrange(x, "c h w -> (h w) c")
return x, x_ids.to(x.device)
listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
def center_crop_to_multiple_of_x(
img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
if isinstance(img, list):
return [center_crop_to_multiple_of_x(_img, x) for _img in img] # type: ignore
w, h = img.size
new_w = (w // x) * x
new_h = (h // x) * x
left = (w - new_w) // 2
top = (h - new_h) // 2
right = left + new_w
bottom = top + new_h
resized = img.crop((left, top, right, bottom))
return resized
def cap_pixels(img: Image.Image | list[Image.Image], k):
if isinstance(img, list):
return [cap_pixels(_img, k) for _img in img]
w, h = img.size
pixel_count = w * h
if pixel_count <= k:
return img
# Scaling factor to reduce total pixels below K
scale = math.sqrt(k / pixel_count)
new_w = int(w * scale)
new_h = int(h * scale)
return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
if isinstance(img, list):
return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
w, h = img.size
if w < min_sidelength or h < min_sidelength:
raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")
if w / h > max_ar or h / w > max_ar:
raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
return img
def to_rgb(img: Image.Image | list[Image.Image]):
if isinstance(img, list):
return [
to_rgb(
_img,
)
for _img in img
]
return img.convert("RGB")
def default_images_prep(
x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
if isinstance(x, list):
return [default_images_prep(e) for e in x] # type: ignore
x_tensor = torchvision.transforms.ToTensor()(x)
return 2 * x_tensor - 1
def default_prep(
img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
img_rgb = to_rgb(img)
img_min = cap_min_pixels(img_rgb) # type: ignore
if limit_pixels is not None:
img_cap = cap_pixels(img_min, limit_pixels) # type: ignore
else:
img_cap = img_min
img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple) # type: ignore
img_tensor = default_images_prep(img_crop)
return img_tensor
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps.tolist()
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
def denoise(
model: Flux2,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float,
# extra img tokens (sequence-wise)
img_cond_seq: Tensor | None = None,
img_cond_seq_ids: Tensor | None = None,
):
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
if img_cond_seq is not None:
assert (
img_cond_seq_ids is not None
), "You need to provide either both or neither of the sequence conditioning"
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
x=img_input,
x_ids=img_input_ids,
timesteps=t_vec,
ctx=txt,
ctx_ids=txt_ids,
guidance=guidance_vec,
)
if img_input_ids is not None:
pred = pred[:, : img.shape[1]]
img = img + (t_prev - t_curr) * pred
return img
def concatenate_images(
images: list[Image.Image],
) -> Image.Image:
"""
Concatenate a list of PIL images horizontally with center alignment and white background.
"""
# If only one image, return a copy of it
if len(images) == 1:
return images[0].copy()
# Convert all images to RGB if not already
images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
# Calculate dimensions for horizontal concatenation
total_width = sum(img.width for img in images)
max_height = max(img.height for img in images)
# Create new image with white background
background_color = (255, 255, 255)
new_img = Image.new("RGB", (total_width, max_height), background_color)
# Paste images with center alignment
x_offset = 0
for img in images:
y_offset = (max_height - img.height) // 2
new_img.paste(img, (x_offset, y_offset))
x_offset += img.width
return new_img

View File

@@ -0,0 +1,82 @@
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""
SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
Output only the revised prompt and nothing else."""
SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X""keep X")
- Make abstractions concrete ("futuristic""glowing cyan neon, metallic panels")
- Keep content PG-13
Output only the final instruction in plain text and nothing else."""
SYSTEM_PROMPT_CONTENT_FILTER = """
You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.
Your task is identifying images and text for copyright concerns and depictions of public personas.
"""
PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.
Output: Respond with only "yes" or "no"
Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)
Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?"
PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.
Output: Respond with only "yes" or "no"
Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)
Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"
The prompt to check is:
-----
{prompt}
-----
Does this prompt have copyright concerns or includes public figures?
""".strip()

356
src/flux2/text_encoder.py Normal file
View File

@@ -0,0 +1,356 @@
from pathlib import Path
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, pipeline
from .sampling import cap_pixels, concatenate_images
from .system_messages import (
PROMPT_IMAGE_INTEGRITY,
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
PROMPT_TEXT_INTEGRITY,
SYSTEM_MESSAGE,
SYSTEM_MESSAGE_UPSAMPLING_I2I,
SYSTEM_MESSAGE_UPSAMPLING_T2I,
SYSTEM_PROMPT_CONTENT_FILTER,
)
OUTPUT_LAYERS = [10, 20, 30]
MAX_LENGTH = 512
NSFW_THRESHOLD = 0.85
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
class Mistral3SmallEmbedder(nn.Module):
def __init__(
self,
model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
torch_dtype: str = "bfloat16",
):
super().__init__()
self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
model_spec,
torch_dtype=getattr(torch, torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
self.yes_token, self.no_token = self.processor.tokenizer.encode(
["yes", "no"], add_special_tokens=False
)
self.max_length = MAX_LENGTH
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
self.nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
def _validate_and_process_images(
self, img: list[list[Image.Image]] | list[Image.Image]
) -> list[list[Image.Image]]:
# Simple validation: ensure it's a list of PIL images or list of lists of PIL images
if not img:
return []
# Check if it's a list of lists or a list of images
if isinstance(img[0], Image.Image):
# It's a list of images, convert to list of lists
img = [[im] for im in img]
# potentially concatenate multiple images to reduce the size
img = [[concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in img]
# cap the pixels
img = [[cap_pixels(img_i, self.upsampling_max_image_size) for img_i in img_i] for img_i in img]
return img
def format_input(
self,
txt: list[str],
system_message: str = SYSTEM_MESSAGE,
img: list[Image.Image] | list[list[Image.Image]] | None = None,
) -> list[list[dict]]:
"""
Format a batch of text prompts into the conversation format expected by apply_chat_template.
Optionally, add images to the input.
Args:
txt: List of text prompts
system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
img: List of images to add to the input.
Returns:
List of conversations, where each conversation is a list of message dicts
"""
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
if img is None or len(img) == 0:
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
else:
assert len(img) == len(txt), "Number of images must match number of prompts"
img = self._validate_and_process_images(img)
messages = [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
]
for _ in cleaned_txt
]
for i, (el, images) in enumerate(zip(messages, img)):
# optionally add the images per batch element.
if images is not None:
el.append(
{
"role": "user",
"content": [{"type": "image", "image": image_obj} for image_obj in images],
}
)
# add the text.
el.append(
{
"role": "user",
"content": [{"type": "text", "text": cleaned_txt[i]}],
}
)
return messages
@torch.no_grad()
def upsample_prompt(
self,
txt: list[str],
img: list[Image.Image] | list[list[Image.Image]] | None = None,
temperature: float = 0.15,
) -> list[str]:
"""
Upsample prompts using the model's generate method.
Args:
txt: List of input prompts to upsample
img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode.
Returns:
List of upsampled prompts
"""
# Set system message based on whether images are provided
if img is None or len(img) == 0 or img[0] is None:
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
else:
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
# Format input messages
messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)
# Process all messages at once
# with image processing a too short max length can throw an error in here.
try:
inputs = self.processor.apply_chat_template(
messages_batch,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=2048,
)
except ValueError as e:
print(
f"Error processing input: {e}, your max length is probably too short, when you have images in the input."
)
raise e
# Move to device
inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)
if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)
# Generate text using the model's generate method
try:
generated_ids = self.model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=temperature,
use_cache=True,
)
# Decode only the newly generated tokens (skip input tokens)
# Extract only the generated portion
input_length = inputs["input_ids"].shape[1]
generated_tokens = generated_ids[:, input_length:]
raw_txt = self.processor.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return raw_txt
except Exception as e:
print(f"Error generating upsampled prompt: {e}, returning original prompt")
return txt
@torch.no_grad()
def forward(self, txt: list[str]):
# Format input messages
messages_batch = self.format_input(txt=txt)
# Process all messages at once
# with image processing a too short max length can throw an error in here.
inputs = self.processor.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_length,
)
# Move to device
input_ids = inputs["input_ids"].to(self.model.device)
attention_mask = inputs["attention_mask"].to(self.model.device)
# Forward pass through the model
output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
return rearrange(out, "b c l d -> b l (c d)")
def yes_no_logit_processor(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
"""
Sets all tokens but yes/no to the minimum.
"""
scores_yes_token = scores[:, self.yes_token].clone()
scores_no_token = scores[:, self.no_token].clone()
scores_min = scores.min()
scores[:, :] = scores_min - 1
scores[:, self.yes_token] = scores_yes_token
scores[:, self.no_token] = scores_no_token
return scores
def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
if isinstance(image, torch.Tensor):
image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
elif isinstance(image, (str, Path)):
image = Image.open(image)
classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
if classification["score"] > NSFW_THRESHOLD:
return True
# 512^2 pixels are enough for checking
w, h = image.size
f = (512**2 / (w * h)) ** 0.5
image = image.resize((int(f * w), int(f * h)))
chat = [
{
"role": "system",
"content": [
{
"type": "text",
"text": SYSTEM_PROMPT_CONTENT_FILTER,
},
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": PROMPT_IMAGE_INTEGRITY,
},
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
},
],
},
]
inputs = self.processor.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device)
inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)
generate_ids = self.model.generate(
**inputs,
max_new_tokens=1,
logits_processor=[self.yes_no_logit_processor],
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token
def test_txt(self, txt: str) -> bool:
chat = [
{
"role": "system",
"content": [
{
"type": "text",
"text": SYSTEM_PROMPT_CONTENT_FILTER,
},
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
},
],
},
]
inputs = self.processor.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device)
generate_ids = self.model.generate(
**inputs,
max_new_tokens=1,
logits_processor=[self.yes_no_logit_processor],
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token

105
src/flux2/util.py Normal file
View File

@@ -0,0 +1,105 @@
import base64
import io
import os
import sys
import huggingface_hub
import torch
from PIL import Image
from safetensors.torch import load_file as load_sft
from .autoencoder import AutoEncoder, AutoEncoderParams
from .model import Flux2, Flux2Params
from .text_encoder import Mistral3SmallEmbedder
FLUX2_MODEL_INFO = {
"flux.2-dev": {
"repo_id": "black-forest-labs/FLUX.2-dev",
"filename": "flux2-dev.safetensors",
"filename_ae": "ae.safetensors",
"params": Flux2Params(),
}
}
def load_flow_model(model_name: str, debug_mode: bool = False, device: str | torch.device = "cuda") -> Flux2:
config = FLUX2_MODEL_INFO[model_name.lower()]
if debug_mode:
config["params"].depth = 1
config["params"].depth_single_blocks = 1
else:
if "FLUX2_MODEL_PATH" in os.environ:
weight_path = os.environ["FLUX2_MODEL_PATH"]
assert os.path.exists(weight_path), f"Provided weight path {weight_path} does not exist"
else:
# download from huggingface
try:
weight_path = huggingface_hub.hf_hub_download(
repo_id=config["repo_id"],
filename=config["filename"],
repo_type="model",
)
except huggingface_hub.errors.RepositoryNotFoundError:
print(
f"Failed to access the model repository. Please check your internet "
f"connection and make sure you've access to {config['repo_id']}."
"Stopping."
)
sys.exit(1)
if not debug_mode:
with torch.device("meta"):
model = Flux2(FLUX2_MODEL_INFO[model_name.lower()]["params"]).to(torch.bfloat16)
print(f"Loading {weight_path} for the FLUX.2 weights")
sd = load_sft(weight_path, device=str(device))
model.load_state_dict(sd, strict=False, assign=True)
return model.to(device)
else:
with torch.device(device):
return Flux2(FLUX2_MODEL_INFO[model_name.lower()]["params"]).to(torch.bfloat16)
def load_mistral_small_embedder(device: str | torch.device = "cuda") -> Mistral3SmallEmbedder:
return Mistral3SmallEmbedder().to(device)
def load_ae(model_name: str, device: str | torch.device = "cuda") -> AutoEncoder:
config = FLUX2_MODEL_INFO[model_name.lower()]
if "AE_MODEL_PATH" in os.environ:
weight_path = os.environ["AE_MODEL_PATH"]
assert os.path.exists(weight_path), f"Provided weight path {weight_path} does not exist"
else:
# download from huggingface
try:
weight_path = huggingface_hub.hf_hub_download(
repo_id=config["repo_id"],
filename=config["filename_ae"],
repo_type="model",
)
except huggingface_hub.errors.RepositoryNotFoundError:
print(
f"Failed to access the model repository. Please check your internet "
f"connection and make sure you've access to {config['repo_id']}."
"Stopping."
)
sys.exit(1)
if isinstance(device, str):
device = torch.device(device)
with torch.device("meta"):
ae = AutoEncoder(AutoEncoderParams())
print(f"Loading {weight_path} for the AutoEncoder weights")
sd = load_sft(weight_path, device=str(device))
ae.load_state_dict(sd, strict=True, assign=True)
return ae.to(device)
def image_to_base64(image: Image.Image) -> str:
"""Convert PIL Image to base64 string."""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str

47
src/flux2/watermark.py Normal file
View File

@@ -0,0 +1,47 @@
import torch
from einops import rearrange
from imwatermark import WatermarkEncoder
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, RGB, H, W) in range [-1, 1]
Returns:
same as input but watermarked
"""
image = 0.5 * image + 0.5
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
# watermarking libary expects input as cv2 BGR format
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
image.device
)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
image = 2 * image - 1
return image
# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)