Update the 4-bit text-encoder and 4-bit DiT diffusers script for initial CPU loading (#12)

* Update diffusers script for initial CPU loading
* Remove transformer

````diff
@@ -77,15 +77,22 @@ The text-encoder is offloaded from VRAM for the transformer to run with `pipe.enable_model_cpu_offload()`
 ```py
 import torch
-from diffusers import Flux2Pipeline
+from diffusers import Flux2Pipeline, AutoModel
+from transformers import Mistral3ForConditionalGeneration
 from diffusers.utils import load_image
 
 repo_id = "diffusers/FLUX.2-dev-bnb-4bit" #quantized text-encoder and DiT. VAE still in bf16
 device = "cuda:0"
 torch_dtype = torch.bfloat16
 
+text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+    repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cpu"
+)
+dit = AutoModel.from_pretrained(
+    repo_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cpu"
+)
 pipe = Flux2Pipeline.from_pretrained(
-    repo_id, torch_dtype=torch_dtype
+    repo_id, text_encoder=text_encoder, transformer=dit, torch_dtype=torch_dtype
 )
 pipe.enable_model_cpu_offload()
 ```
````
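
For context on the pattern: `device_map="cpu"` keeps the quantized text-encoder and DiT in system RAM at load time, and `pipe.enable_model_cpu_offload()` then moves each component to the GPU only while it runs. Below is a minimal inference sketch under that setup; it is not part of this commit, and the prompt and generation parameters are illustrative assumptions following the usual diffusers pipeline call convention:

```py
import torch

# Assumes `pipe` was built as in the diff above, with
# pipe.enable_model_cpu_offload() already called.
# The prompt, step count, and guidance value are hypothetical examples.
image = pipe(
    prompt="a photo of a forest with mist swirling around the tree trunks",
    num_inference_steps=28,
    guidance_scale=4.0,
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
image.save("flux2_output.png")
```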