FLUX.2 launch
scripts/cli.py (new file, +525 lines)
@@ -0,0 +1,525 @@
import json
import os
import random
import shlex
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import torch
from einops import rearrange
from PIL import ExifTags, Image

from flux2.openrouter_api_client import DEFAULT_SAMPLING_PARAMS, OpenRouterAPIClient
from flux2.sampling import (
    batched_prc_img,
    batched_prc_txt,
    denoise,
    encode_image_refs,
    get_schedule,
    scatter_ids,
)
from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_mistral_small_embedder

# from flux2.watermark import embed_watermark


@dataclass
class Config:
    prompt: str = "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture"
    seed: Optional[int] = None
    width: int = 1360
    height: int = 768
    num_steps: int = 50
    guidance: float = 4.0
    input_images: List[Path] = field(default_factory=list)
    match_image_size: Optional[int] = None  # Index of input_images to match size from
    upsample_prompt_mode: Literal["none", "local", "openrouter"] = "none"
    openrouter_model: str = "mistralai/pixtral-large-2411"  # OpenRouter model name

    def copy(self) -> "Config":
        return Config(
            prompt=self.prompt,
            seed=self.seed,
            width=self.width,
            height=self.height,
            num_steps=self.num_steps,
            guidance=self.guidance,
            input_images=list(self.input_images),
            match_image_size=self.match_image_size,
            upsample_prompt_mode=self.upsample_prompt_mode,
            openrouter_model=self.openrouter_model,
        )


DEFAULTS = Config()

INT_FIELDS = {"width", "height", "seed", "num_steps", "match_image_size"}
FLOAT_FIELDS = {"guidance"}
LIST_FIELDS = {"input_images"}
UPSAMPLING_MODE_FIELDS = ("none", "local", "openrouter")
STR_FIELDS = {"openrouter_model"}


def coerce_value(key: str, raw: str):
    """Convert a raw string to the correct field type."""
    if key in INT_FIELDS:
        if raw.lower() == "none" or raw == "":
            return None
        return int(raw)

    if key in FLOAT_FIELDS:
        return float(raw)

    if key in STR_FIELDS:
        return raw.strip().strip('"').strip("'")

    if key in LIST_FIELDS:
        # Handle empty list cases
        if raw == "" or raw == "[]":
            return []
        # Accept comma-separated or space-separated; strip quotes.
        items = []
        # If user passed a single token that contains commas, split on commas.
        tokens = [raw] if ("," in raw and " " not in raw) else shlex.split(raw)
        for tok in tokens:
            for part in tok.split(","):
                part = part.strip()
                if part:
                    if os.path.exists(part):
                        items.append(Path(part))
                    else:
                        print(f"File {part} not found. Skipping for now. Please check your path")
        return items

    if key == "upsample_prompt_mode":
        v = str(raw).strip().strip('"').strip("'").lower()
        if v in UPSAMPLING_MODE_FIELDS:
            return v
        raise ValueError(
            f"invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(UPSAMPLING_MODE_FIELDS)}"
        )

    # plain strings
    return raw
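
# Illustrative coercions (hypothetical values, not exercised by the script itself):
#   coerce_value("width", "768")                 -> 768
#   coerce_value("seed", "none")                 -> None
#   coerce_value("guidance", "3.5")              -> 3.5
#   coerce_value("input_images", "a.png,b.png")  -> [Path("a.png"), Path("b.png")] (existing files only)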


def apply_updates(cfg: Config, updates: Dict[str, Any]) -> None:
    for k, v in updates.items():
        if not hasattr(cfg, k):
            print(f" ! unknown key: {k}", file=sys.stderr)
            continue
        # Validate upsample_prompt_mode
        if k == "upsample_prompt_mode":
            valid_modes = {"none", "local", "openrouter"}
            if v not in valid_modes:
                print(
                    f" ! Invalid upsample_prompt_mode: {v}. Must be one of: {', '.join(valid_modes)}",
                    file=sys.stderr,
                )
                continue
        setattr(cfg, k, v)


def parse_key_values(line: str) -> Dict[str, Any]:
    """
    Parse shell-like 'key=value' pairs. Values can be quoted.
    Example: prompt="a dog" width=768 input_images="in1.png,in2.jpg"
    """
    updates: Dict[str, Any] = {}
    for token in shlex.split(line):
        if "=" not in token:
            # Allow bare commands like: run, show, reset, quit
            updates[token] = True
            continue
        key, val = token.split("=", 1)
        key = key.strip()
        val = val.strip()
        try:
            updates[key] = coerce_value(key, val)
        except Exception as e:
            print(f" ! could not parse {key}={val!r}: {e}", file=sys.stderr)
    return updates
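
# Illustrative parse (hypothetical input):
#   parse_key_values('prompt="a dog" width=768 run')
#   -> {"prompt": "a dog", "width": 768, "run": True}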


def print_config(cfg: Config):
    d = asdict(cfg)
    d["input_images"] = [str(p) for p in cfg.input_images]
    print("Current config:")
    for k in [
        "prompt",
        "seed",
        "width",
        "height",
        "num_steps",
        "guidance",
        "input_images",
        "match_image_size",
        "upsample_prompt_mode",
        "openrouter_model",
    ]:
        print(f" {k}: {d[k]}")
    print()


def print_help():
    print("""
Available commands:
  [Enter]              - Run generation with current config
  run                  - Run generation with current config
  show                 - Show current configuration
  reset                - Reset configuration to defaults
  help, h, ?           - Show this help message
  quit, q, exit        - Exit the program

Setting parameters:
  key=value            - Update a config parameter (shows updated config, doesn't run)

Examples:
  prompt="a cat in a hat"
  width=768 height=768
  seed=42
  num_steps=30
  guidance=3.5
  input_images="img1.jpg,img2.jpg"
  match_image_size=0 (use dimensions from first input image)
  upsample_prompt_mode="none" (prompt upsampling mode: "none", "local", or "openrouter")
  openrouter_model="mistralai/pixtral-large-2411" (OpenRouter model name)

You can combine parameter updates:
  prompt="sunset" width=1920 height=1080

Parameters:
  prompt               - Text prompt for generation (string)
  seed                 - Random seed (integer or 'none' for random)
  width                - Output width in pixels (integer)
  height               - Output height in pixels (integer)
  num_steps            - Number of denoising steps (integer)
  guidance             - Guidance scale (float)
  input_images         - Comma-separated list of input image paths (list)
  match_image_size     - Index of input image to match dimensions from (integer, 0-based)
  upsample_prompt_mode - Prompt upsampling mode: "none" (default), "local", or "openrouter" (string)
  openrouter_model     - OpenRouter model name (string, default: "mistralai/pixtral-large-2411")
                         Examples: "mistralai/pixtral-large-2411", "qwen/qwen3-vl-235b-a22b-instruct", etc.
  Note: For "openrouter" mode, set OPENROUTER_API_KEY environment variable
""")


# ---------- Main Loop ----------


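# Example invocations via python-fire (flag names mirror main()'s signature; extra
# --key=value flags are forwarded to Config fields), assuming the flux2 package is importable:
#   python scripts/cli.py --prompt="a cat in a hat" --single_eval
#   python scripts/cli.py --model_name=flux.2-dev --cpu_offloading --width=1024 --height=1024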
def main(
    model_name: str = "flux.2-dev",
    single_eval: bool = False,
    prompt: str | None = None,
    debug_mode: bool = False,
    cpu_offloading: bool = False,
    **overwrite,
):
    assert (
        model_name.lower() in FLUX2_MODEL_INFO
    ), f"{model_name} is not available, choose from {FLUX2_MODEL_INFO.keys()}"

    torch_device = torch.device("cuda")

    mistral = load_mistral_small_embedder()
    model = load_flow_model(
        model_name, debug_mode=debug_mode, device="cpu" if cpu_offloading else torch_device
    )
    ae = load_ae(model_name)
    ae.eval()
    mistral.eval()

    # API client will be initialized lazily when needed
    openrouter_api_client: Optional[OpenRouterAPIClient] = None

    cfg = DEFAULTS.copy()
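    # Command-line overrides passed via **overwrite (e.g. `--width=1024`) are routed
    # through the same key=value parser as interactive input.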
    changes = [f"{key}={value}" for key, value in overwrite.items()]
    updates = parse_key_values(" ".join(changes))
    apply_updates(cfg, updates)
    if prompt is not None:
        cfg.prompt = prompt
    print_config(cfg)

    while True:
        if not single_eval:
            try:
                line = input("> ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nbye!")
                break

            if not line:
                # Empty -> run with current config
                cmd = "run"
                updates = {}
            else:
                try:
                    updates = parse_key_values(line)
                except Exception as e:  # noqa: BLE001
                    print(f" ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
                    print(
                        " ! Please check your syntax (e.g., matching quotes) and try again.\n",
                        file=sys.stderr,
                    )
                    continue

                if "prompt" in updates and mistral.test_txt(updates["prompt"]):
                    print(
                        "Your prompt has been flagged for potential copyright or public personas concerns. Please choose another."
                    )
                    updates.pop("prompt")

                if "input_images" in updates:
                    flagged = False
                    for image in updates["input_images"]:
                        if mistral.test_image(image):
                            print(f"The image {image} has been flagged as unsuitable. Please choose another.")
                            flagged = True
                    if flagged:
                        updates.pop("input_images")

                # If the line was only 'run' / 'show' / ... it will appear as {cmd: True}
                # If it had key=val pairs, there may be no bare command -> just update config
                bare_cmds = [k for k, v in updates.items() if v is True and (k.isalpha() or k == "?")]
                cmd = bare_cmds[0] if bare_cmds else None

                # Remove bare commands from updates so they don't get applied as fields
                for c in bare_cmds:
                    updates.pop(c, None)

            if cmd in ("quit", "q", "exit"):
                print("bye!")
                break
            elif cmd == "reset":
                cfg = DEFAULTS.copy()
                print_config(cfg)
                continue
            elif cmd == "show":
                print_config(cfg)
                continue
            elif cmd in ("help", "h", "?"):
                print_help()
                continue

            # Apply key=value changes
            if updates:
                apply_updates(cfg, updates)
                print_config(cfg)
                continue

            # Only run if explicitly requested (empty line or 'run' command)
            if cmd != "run":
                if cmd is not None:
                    print(f" ! Unknown command: '{cmd}'", file=sys.stderr)
                    print(" ! Type 'help' to see available commands.\n", file=sys.stderr)
                continue

        try:
            # Load input images first to potentially match dimensions
            img_ctx = [Image.open(input_image) for input_image in cfg.input_images]

            # Apply match_image_size if specified
            width = cfg.width
            height = cfg.height
            if cfg.match_image_size is not None:
                if cfg.match_image_size < 0 or cfg.match_image_size >= len(img_ctx):
                    print(
                        f" ! match_image_size={cfg.match_image_size} is out of range (0-{len(img_ctx)-1})",
                        file=sys.stderr,
                    )
                    print(f" ! Using default dimensions: {width}x{height}", file=sys.stderr)
                else:
                    ref_img = img_ctx[cfg.match_image_size]
                    width, height = ref_img.size
                    print(f" Matched dimensions from image {cfg.match_image_size}: {width}x{height}")

            seed = cfg.seed if cfg.seed is not None else random.randrange(2**31)
            out_dir = Path("output")
            out_dir.mkdir(exist_ok=True)
            output_name = out_dir / f"sample_{len(list(out_dir.glob('*')))}.png"

            with torch.no_grad():
                ref_tokens, ref_ids = encode_image_refs(ae, img_ctx)

                if cfg.upsample_prompt_mode == "openrouter":
                    try:
                        # Ensure API key is available, otherwise prompt the user
                        api_key = os.environ.get("OPENROUTER_API_KEY", "").strip()
                        if not api_key:
                            try:
                                entered = input(
                                    "OPENROUTER_API_KEY not set. Enter it now (leave blank to skip OpenRouter upsampling): "
                                ).strip()
                            except (EOFError, KeyboardInterrupt):
                                entered = ""
                            if entered:
                                os.environ["OPENROUTER_API_KEY"] = entered
                            else:
                                print(
                                    " ! No API key provided; disabling OpenRouter upsampling",
                                    file=sys.stderr,
                                )
                                cfg.upsample_prompt_mode = "none"
                                prompt = cfg.prompt
                                # Skip OpenRouter flow

                        # Only proceed if still in openrouter mode (not disabled above)
                        if cfg.upsample_prompt_mode == "openrouter":
                            # Let user specify sampling params, or use model defaults if available
                            sampling_params_input = ""
                            try:
                                sampling_params_input = input(
                                    "Enter OpenRouter sampling params as JSON or key=value (blank to use defaults): "
                                ).strip()
                            except (EOFError, KeyboardInterrupt):
                                sampling_params_input = ""

                            sampling_params: Dict[str, Any] = {}
                            if sampling_params_input:
                                # Try JSON first
                                parsed_ok = False
                                try:
                                    parsed = json.loads(sampling_params_input)
                                    if isinstance(parsed, dict):
                                        sampling_params = parsed
                                        parsed_ok = True
                                except Exception:
                                    parsed_ok = False
                                if not parsed_ok:
                                    # Fallback: parse key=value pairs separated by spaces or commas
                                    tokens = [
                                        tok
                                        for tok in sampling_params_input.replace(",", " ").split(" ")
                                        if tok
                                    ]
                                    for tok in tokens:
                                        if "=" not in tok:
                                            continue
                                        k, v = tok.split("=", 1)
                                        v_str = v.strip()
                                        v_low = v_str.lower()
                                        if v_low in {"true", "false"}:
                                            val: Any = v_low == "true"
                                        else:
                                            try:
                                                if "." in v_str:
                                                    num = float(v_str)
                                                    val = int(num) if num.is_integer() else num
                                                else:
                                                    val = int(v_str)
                                            except Exception:
                                                val = v_str
                                        sampling_params[k.strip()] = val
                                print(f" Using custom OpenRouter sampling params: {sampling_params}")
                            else:
                                model_key = cfg.openrouter_model
                                default_params = DEFAULT_SAMPLING_PARAMS.get(model_key)
                                if default_params:
                                    sampling_params = default_params
                                    print(
                                        f" Using default OpenRouter sampling params for {model_key}: {sampling_params}"
                                    )
                                else:
                                    print(
                                        f" Setting no OpenRouter sampling params: not set for this model ({model_key})"
                                    )

                            # Initialize or reinitialize client if model changed
                            if (
                                openrouter_api_client is None
                                or openrouter_api_client.model != cfg.openrouter_model
                                or getattr(openrouter_api_client, "sampling_params", None) != sampling_params
                            ):
                                openrouter_api_client = OpenRouterAPIClient(
                                    model=cfg.openrouter_model,
                                    sampling_params=sampling_params,
                                )
                            else:
                                # Ensure client uses latest sampling params
                                openrouter_api_client.sampling_params = sampling_params
                            upsampled_prompts = openrouter_api_client.upsample_prompt(
                                [cfg.prompt], img=[img_ctx] if img_ctx else None
                            )
                            prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
                    except Exception as e:
                        print(f" ! Failed to upsample prompt via OpenRouter API: {e}", file=sys.stderr)
                        print(
                            " ! Disabling OpenRouter upsampling and falling back to original prompt",
                            file=sys.stderr,
                        )
                        cfg.upsample_prompt_mode = "none"
                        prompt = cfg.prompt
                elif cfg.upsample_prompt_mode == "local":
                    # Use local model for upsampling
                    upsampled_prompts = mistral.upsample_prompt(
                        [cfg.prompt], img=[img_ctx] if img_ctx else None
                    )
                    prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
                else:
                    # upsample_prompt_mode == "none" or invalid value
                    prompt = cfg.prompt

                print("Generating with prompt: ", prompt)

                ctx = mistral([prompt]).to(torch.bfloat16)
                ctx, ctx_ids = batched_prc_txt(ctx)

                if cpu_offloading:
                    mistral = mistral.cpu()
                    torch.cuda.empty_cache()
                    model = model.to(torch_device)

                # Create noise
                shape = (1, 128, height // 16, width // 16)
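                # Note: the initial latent is 128 channels at 1/16 of the pixel resolution,
                # so width and height are presumably expected to be multiples of 16.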
                generator = torch.Generator(device="cuda").manual_seed(seed)
                randn = torch.randn(shape, generator=generator, dtype=torch.bfloat16, device="cuda")
                x, x_ids = batched_prc_img(randn)

                timesteps = get_schedule(cfg.num_steps, x.shape[1])
                x = denoise(
                    model,
                    x,
                    x_ids,
                    ctx,
                    ctx_ids,
                    timesteps=timesteps,
                    guidance=cfg.guidance,
                    img_cond_seq=ref_tokens,
                    img_cond_seq_ids=ref_ids,
                )
                x = torch.cat(scatter_ids(x, x_ids)).squeeze(2)
                x = ae.decode(x).float()
                # x = embed_watermark(x)

                if cpu_offloading:
                    model = model.cpu()
                    torch.cuda.empty_cache()
                    mistral = mistral.to(torch_device)

                x = x.clamp(-1, 1)
                x = rearrange(x[0], "c h w -> h w c")
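
                # 127.5 * (x + 1.0) maps values from [-1, 1] to [0, 255] before the uint8 cast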
                img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
                if mistral.test_image(img):
                    print("Your output has been flagged. Please choose another prompt / input image combination")
                else:
                    exif_data = Image.Exif()
                    exif_data[ExifTags.Base.Software] = "AI generated;flux2"
                    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
                    img.save(output_name, exif=exif_data, quality=95, subsampling=0)
                    print(f"Saved {output_name}")

        except Exception as e:  # noqa: BLE001
            print(f"\n ERROR: {type(e).__name__}: {e}", file=sys.stderr)
            print(" The model is still loaded. Please fix the error and try again.\n", file=sys.stderr)

        if single_eval:
            break


if __name__ == "__main__":
    from fire import Fire

    Fire(main)