Update 4-bit text-encoder and 4-bit DiT on diffusers script for initial CPU loading (#12)
* Update diffusers script for initial CPU loading * remove transformer
This commit is contained in:
@@ -88,7 +88,7 @@ The below example should run on a RTX 4090. For more examples check the [diffuse
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
|
from diffusers import Flux2Pipeline
|
||||||
from diffusers.utils import load_image
|
from diffusers.utils import load_image
|
||||||
from huggingface_hub import get_token
|
from huggingface_hub import get_token
|
||||||
import requests
|
import requests
|
||||||
@@ -112,7 +112,7 @@ def remote_text_encoder(prompts):
|
|||||||
return prompt_embeds.to(device)
|
return prompt_embeds.to(device)
|
||||||
|
|
||||||
pipe = Flux2Pipeline.from_pretrained(
|
pipe = Flux2Pipeline.from_pretrained(
|
||||||
repo_id, transformer=transformer, text_encoder=None, torch_dtype=torch_dtype
|
repo_id, text_encoder=None, torch_dtype=torch_dtype
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
|
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
|
||||||
|
|||||||
@@ -77,15 +77,22 @@ The text-encoder is offloaded from VRAM for the transformer to run with `pipe.en
|
|||||||
|
|
||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
from diffusers import Flux2Pipeline
|
from diffusers import Flux2Pipeline, AutoModel
|
||||||
|
from transformers import Mistral3ForConditionalGeneration
|
||||||
from diffusers.utils import load_image
|
from diffusers.utils import load_image
|
||||||
|
|
||||||
repo_id = "diffusers/FLUX.2-dev-bnb-4bit" #quantized text-encoder and DiT. VAE still in bf16
|
repo_id = "diffusers/FLUX.2-dev-bnb-4bit" #quantized text-encoder and DiT. VAE still in bf16
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
torch_dtype = torch.bfloat16
|
torch_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
||||||
|
repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cpu"
|
||||||
|
)
|
||||||
|
dit = AutoModel.from_pretrained(
|
||||||
|
repo_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cpu"
|
||||||
|
)
|
||||||
pipe = Flux2Pipeline.from_pretrained(
|
pipe = Flux2Pipeline.from_pretrained(
|
||||||
repo_id, torch_dtype=torch_dtype
|
repo_id, text_encoder=text_encoder, transformer=dit, torch_dtype=torch_dtype
|
||||||
)
|
)
|
||||||
pipe.enable_model_cpu_offload()
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user