diff --git a/README.md b/README.md
index 03850f4..4350968 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ The below example should run on a RTX 4090. For more examples check the [diffuse
 
 ```python
 import torch
-from diffusers import Flux2Pipeline, Flux2Transformer2DModel
+from diffusers import Flux2Pipeline
 from diffusers.utils import load_image
 from huggingface_hub import get_token
 import requests
@@ -112,7 +112,7 @@ def remote_text_encoder(prompts):
     return prompt_embeds.to(device)
 
 pipe = Flux2Pipeline.from_pretrained(
-    repo_id, transformer=transformer, text_encoder=None, torch_dtype=torch_dtype
+    repo_id, text_encoder=None, torch_dtype=torch_dtype
 ).to(device)
 
 prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
diff --git a/docs/flux2_dev_hf.md b/docs/flux2_dev_hf.md
index acee3da..9c057c9 100644
--- a/docs/flux2_dev_hf.md
+++ b/docs/flux2_dev_hf.md
@@ -77,15 +77,22 @@ The text-encoder is offloaded from VRAM for the transformer to run with `pipe.enable_model_cpu_offload()`
 
 ```py
 import torch
-from diffusers import Flux2Pipeline
+from diffusers import Flux2Pipeline, AutoModel
+from transformers import Mistral3ForConditionalGeneration
 from diffusers.utils import load_image
 
 repo_id = "diffusers/FLUX.2-dev-bnb-4bit" #quantized text-encoder and DiT. VAE still in bf16
 device = "cuda:0"
 torch_dtype = torch.bfloat16
 
+text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+    repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cpu"
+)
+dit = AutoModel.from_pretrained(
+    repo_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cpu"
+)
 pipe = Flux2Pipeline.from_pretrained(
-    repo_id, torch_dtype=torch_dtype
+    repo_id, text_encoder=text_encoder, transformer=dit, torch_dtype=torch_dtype
 )
 pipe.enable_model_cpu_offload()
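
For reference, a minimal sketch of what running the docs example looks like after this change: the quantized text encoder and DiT are loaded explicitly on CPU, `enable_model_cpu_offload()` moves components to the GPU on demand, and a standard diffusers pipeline call produces the image. The prompt, step count, guidance scale, and seed below are illustrative assumptions, not values taken from this diff.

```py
# Sketch only: continues the docs example after pipe.enable_model_cpu_offload().
# All sampling parameters here are illustrative assumptions, not part of the PR.
prompt = "A macro photo of a hermit crab using a soda can as its shell"  # hypothetical prompt

image = pipe(
    prompt=prompt,
    num_inference_steps=28,  # assumed typical value; tune for quality vs. speed
    guidance_scale=4.0,      # assumed illustrative value
    generator=torch.Generator(device=device).manual_seed(0),  # for reproducibility
).images[0]
image.save("flux2_dev_4bit.png")
```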