# branch: master
# flux1.py
# 21662 bytes
# pip3 install sentencepiece

# This file incorporates code from the following:
# Github Name                    | License | Link
# black-forest-labs/flux         | Apache  | https://github.com/black-forest-labs/flux/tree/main/model_licenses

from tinygrad import Tensor, nn, dtypes, TinyJit
from tinygrad.nn.state import safe_load, load_state_dict
from tinygrad.helpers import fetch, tqdm, colored
from sdxl import FirstStage
from extra.models.clip import FrozenClosedClipEmbedder
from extra.models.t5 import T5Embedder
import numpy as np

import math, time, argparse, tempfile
from typing import List, Dict, Optional, Union, Tuple, Callable
from dataclasses import dataclass
from pathlib import Path
from PIL import Image

# Download locations for every checkpoint this script fetches, keyed by the names used in the loaders below.
urls:Dict[str, str] = {
  # diffusion transformers (selected with --name)
  "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors",
  "flux-dev": "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev.sft",
  # autoencoder used to decode latents back to pixels
  "ae": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors",
  # T5 text encoder (two safetensors shards) and its sentencepiece tokenizer model
  "T5_1_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00001-of-00002.safetensors",
  "T5_2_of_2": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder_2/model-00002-of-00002.safetensors",
  "T5_tokenizer": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/tokenizer_2/spiece.model",
  # CLIP text encoder
  "clip": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/text_encoder/model.safetensors",
}

def tensor_identity(x:Tensor) -> Tensor:
  """Return `x` untouched (stand-in for an embedder that a model variant does not have)."""
  return x

class AutoEncoder:
  """Decoder-only wrapper around the Flux autoencoder: undoes the latent normalization, then decodes."""
  def __init__(self, scale_factor:float, shift_factor:float):
    # positional arguments are interpreted by sdxl.FirstStage.Decoder; 16 presumably matches the latent channel count
    self.decoder = FirstStage.Decoder(128, 3, 3, 16, [1, 2, 4, 4], 2, 256)
    self.scale_factor, self.shift_factor = scale_factor, shift_factor

  def decode(self, z:Tensor) -> Tensor:
    """Map normalized latents `z` back to the decoder's input range and run the decoder."""
    denormalized = z / self.scale_factor + self.shift_factor
    return self.decoder(denormalized)

# Conditioner
class ClipEmbedder(FrozenClosedClipEmbedder):
  """CLIP text encoder that returns one pooled embedding per prompt."""
  def __call__(self, texts:Union[str, List[str], Tensor]) -> Tensor:
    """
    Encode a prompt or a list/tuple of prompts.

    :param texts: a single string, or a list/tuple of strings (other inputs fail the assert below).
    :return: a (len(texts), D) tensor: for each prompt, the text-model hidden state at the first position
             holding that prompt's largest token id (the same position `argmax` selects).
    :raises ValueError: if an empty list of prompts is given.
    """
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    if len(texts) == 0: raise ValueError("expected at least one prompt")
    encoded = [self.tokenizer.encode(text) for text in texts]
    tokens = Tensor.cat(*[Tensor(ids) for ids in encoded], dim=0).reshape(len(texts), -1)
    hidden = self.transformer.text_model(tokens)
    # Pick the pooling position per row. An argmax over the flattened concatenation of all prompts only
    # indexes correctly for a single prompt; with several it points at the wrong (or a nonexistent) column.
    # list.index(max(...)) matches argmax semantics: first occurrence of the largest id (presumably EOS).
    pooled_at = []
    for ids in encoded:
      seq = list(ids)
      pooled_at.append(seq.index(max(seq)))
    return Tensor.cat(*[hidden[row:row+1, pos] for row, pos in enumerate(pooled_at)], dim=0)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def attention(q:Tensor, k:Tensor, v:Tensor, pe:Tensor) -> Tensor:
  """Rotate q/k with the rotary table `pe`, run scaled dot-product attention, and merge heads into features."""
  q_rot, k_rot = apply_rope(q, k, pe)
  out = Tensor.scaled_dot_product_attention(q_rot, k_rot, v)
  # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
  return out.rearrange("B H L D -> B L (H D)")

def rope(pos:Tensor, dim:int, theta:int) -> Tensor:
  """
  Build rotary-embedding rotation matrices.

  :param pos: 2-D (b, n) positions (the rearrange below requires exactly that rank).
  :param dim: rotary width for this axis; must be even, yields dim // 2 frequencies.
  :param theta: frequency base.
  :return: float tensor of shape (b, n, dim // 2, 2, 2), one [[cos, -sin], [sin, cos]] matrix per entry.
  """
  assert dim % 2 == 0
  # NOTE: the reference implementation computes this in torch.float64
  exponents = Tensor.arange(0, dim, 2, dtype=dtypes.float32, device=pos.device) / dim
  inv_freq = 1.0 / (theta**exponents)
  angles = Tensor.einsum("...n,d->...nd", pos, inv_freq)
  cos, sin = Tensor.cos(angles), Tensor.sin(angles)
  flat = Tensor.stack(cos, -sin, sin, cos, dim=-1)
  return flat.rearrange("b n d (i j) -> b n d i j", i=2, j=2).float()

def apply_rope(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
  """Rotate consecutive feature pairs of queries and keys by the 2x2 matrices in `freqs_cis` (from rope())."""
  def _rotate(x:Tensor) -> Tensor:
    # view the last axis as (pairs, 1, 2), rotate in float32, then restore the original shape and dtype
    pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
    return rotated.reshape(*x.shape).cast(x.dtype)
  return _rotate(xq), _rotate(xk)


# https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND:
  """Rotary embedding for multi-axis position ids: one rope() table per axis, joined on the frequency axis."""
  def __init__(self, dim:int, theta:int, axes_dim:List[int]):
    self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

  def __call__(self, ids:Tensor) -> Tensor:
    # ids[..., axis] holds the coordinate along one axis; each rope() call yields (b, n, axes_dim[axis] // 2, 2, 2)
    tables = [rope(ids[..., axis], self.axes_dim[axis], self.theta) for axis in range(ids.shape[-1])]
    # singleton axis at position 1 broadcasts against the head axis of (B, H, L, D) queries/keys
    return Tensor.cat(*tables, dim=-3).unsqueeze(1)

class MLPEmbedder:
  """Two-layer MLP (Linear -> SiLU -> Linear) projecting `in_dim` features to `hidden_dim`."""
  def __init__(self, in_dim:int, hidden_dim:int):
    self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
    self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    hidden = self.in_layer(x).silu()
    return self.out_layer(hidden)

class QKNorm:
  """Independent RMSNorms for queries and keys over the per-head feature axis."""
  def __init__(self, dim:int):
    self.query_norm = nn.RMSNorm(dim)
    self.key_norm = nn.RMSNorm(dim)

  def __call__(self, q:Tensor, k:Tensor) -> Tuple[Tensor, Tensor]:
    normed_q = self.query_norm(q)
    normed_k = self.key_norm(k)
    return normed_q, normed_k

class SelfAttention:
  """Multi-head self-attention with a fused qkv projection, QK-norm and rotary position embedding."""
  def __init__(self, dim:int, num_heads:int = 8, qkv_bias:bool = False):
    self.num_heads = num_heads
    head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.norm = QKNorm(head_dim)
    self.proj = nn.Linear(dim, dim)

  def __call__(self, x:Tensor, pe:Tensor) -> Tensor:
    # split the fused projection into q, k, v, each shaped (B, H, L, D)
    q, k, v = self.qkv(x).rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)
    return self.proj(attention(q, k, v, pe=pe))

@dataclass
class ModulationOut:
  """One shift/scale/gate triple produced by Modulation."""
  shift:Tensor  # added after the (1 + scale) multiply on the normalized input
  scale:Tensor  # applied as (1 + scale) * normalized_input
  gate:Tensor   # multiplies a branch output before it is added to the residual stream

class Modulation:
  """Project a conditioning vector into one (double=False) or two (double=True) ModulationOut triples."""
  def __init__(self, dim:int, double:bool):
    self.is_double = double
    self.multiplier = 6 if double else 3
    self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

  def __call__(self, vec:Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
    # [:, None, :] inserts a sequence axis so every chunk broadcasts over the tokens it modulates
    chunks = self.lin(vec.silu())[:, None, :].chunk(self.multiplier, dim=-1)
    first = ModulationOut(*chunks[:3])
    if not self.is_double:
      return first, None
    return first, ModulationOut(*chunks[3:])

class DoubleStreamBlock:
  """
  Flux double-stream block. Image and text tokens have separate modulation, attention projections,
  norms and MLPs, but attend jointly over the concatenated [txt, img] sequence.
  """
  def __init__(self, hidden_size:int, num_heads:int, mlp_ratio:float, qkv_bias:bool = False):
    mlp_hidden_dim = int(hidden_size * mlp_ratio)
    self.num_heads = num_heads
    self.hidden_size = hidden_size

    def make_norm() -> nn.LayerNorm:
      return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    def make_mlp() -> list:
      return [nn.Linear(hidden_size, mlp_hidden_dim, bias=True), Tensor.gelu, nn.Linear(mlp_hidden_dim, hidden_size, bias=True)]

    # image stream
    self.img_mod = Modulation(hidden_size, double=True)
    self.img_norm1 = make_norm()
    self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
    self.img_norm2 = make_norm()
    self.img_mlp = make_mlp()

    # text stream: identical structure, separate weights
    self.txt_mod = Modulation(hidden_size, double=True)
    self.txt_norm1 = make_norm()
    self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
    self.txt_norm2 = make_norm()
    self.txt_mlp = make_mlp()

  def _qkv(self, x:Tensor, norm:nn.LayerNorm, mod:ModulationOut, attn:SelfAttention) -> Tuple[Tensor, Tensor, Tensor]:
    """Modulate one stream, project it to per-head (B, H, L, D) q/k/v and apply that stream's QK-norm."""
    modulated = (1 + mod.scale) * norm(x) + mod.shift
    q, k, v = attn.qkv(modulated).rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = attn.norm(q, k)
    return q, k, v

  @staticmethod
  def _residual(x:Tensor, attn_out:Tensor, mod1:ModulationOut, mod2:ModulationOut, attn:SelfAttention, norm2:nn.LayerNorm, mlp:list) -> Tensor:
    """Gated attention residual followed by a gated, modulated MLP residual for one stream."""
    x = x + mod1.gate * attn.proj(attn_out)
    return x + mod2.gate * ((1 + mod2.scale) * norm2(x) + mod2.shift).sequential(mlp)

  def __call__(self, img:Tensor, txt:Tensor, vec:Tensor, pe:Tensor) -> tuple[Tensor, Tensor]:
    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)
    assert img_mod2 is not None and txt_mod2 is not None

    img_q, img_k, img_v = self._qkv(img, self.img_norm1, img_mod1, self.img_attn)
    txt_q, txt_k, txt_v = self._qkv(txt, self.txt_norm1, txt_mod1, self.txt_attn)

    # joint attention: text tokens first, then image tokens, along the sequence axis (axis 2 of B H L D)
    q, k, v = [Tensor.cat(t_part, i_part, dim=2) for t_part, i_part in ((txt_q, img_q), (txt_k, img_k), (txt_v, img_v))]
    joint = attention(q, k, v, pe=pe)
    n_txt = txt.shape[1]
    txt_attn, img_attn = joint[:, :n_txt], joint[:, n_txt:]

    # per-stream residual updates
    img = self._residual(img, img_attn, img_mod1, img_mod2, self.img_attn, self.img_norm2, self.img_mlp)
    txt = self._residual(txt, txt_attn, txt_mod1, txt_mod2, self.txt_attn, self.txt_norm2, self.txt_mlp)
    return img, txt


class SingleStreamBlock:
  """
  A DiT block with parallel linear layers as described in
  https://arxiv.org/abs/2302.05442 and adapted modulation interface.

  One fused linear produces q/k/v and the MLP input; a second linear consumes the attention
  output concatenated with the activated MLP branch.
  """

  def __init__(self,hidden_size:int, num_heads:int, mlp_ratio:float=4.0, qk_scale:Optional[float]=None):
    self.hidden_dim = hidden_size
    self.num_heads = num_heads
    head_dim = hidden_size // num_heads
    # NOTE: stored but never read by __call__
    self.scale = qk_scale or head_dim**-0.5

    self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
    # fused: qkv projection + mlp input
    self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
    # fused: attention output projection + mlp output
    self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

    self.norm = QKNorm(head_dim)

    self.hidden_size = hidden_size
    self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

    self.mlp_act = Tensor.gelu
    self.modulation = Modulation(hidden_size, double=False)

  def __call__(self, x:Tensor, vec:Tensor, pe:Tensor) -> Tensor:
    mod, _ = self.modulation(vec)
    modulated = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    fused = self.linear1(modulated)
    qkv, mlp_in = fused.split([3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

    q, k, v = qkv.rearrange("B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k)
    attn_out = attention(q, k, v, pe=pe)

    # both branches are joined on the feature axis and projected back by one linear layer
    merged = attn_out.cat(self.mlp_act(mlp_in), dim=2)
    return x + mod.gate * self.linear2(merged)


class LastLayer:
  """Final adaptive-norm modulation followed by a projection to patch_size**2 * out_channels features."""
  def __init__(self, hidden_size:int, patch_size:int, out_channels:int):
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
    self.adaLN_modulation:List[Callable[[Tensor], Tensor]] = [Tensor.silu, nn.Linear(hidden_size, 2 * hidden_size, bias=True)]

  def __call__(self, x:Tensor, vec:Tensor) -> Tensor:
    params = vec.sequential(self.adaLN_modulation)
    shift, scale = params.chunk(2, dim=1)
    # [:, None, :] broadcasts the per-sample shift/scale over the sequence axis
    normed = self.norm_final(x)
    modulated = (1 + scale[:, None, :]) * normed + shift[:, None, :]
    return self.linear(modulated)

def timestep_embedding(t:Tensor, dim:int, max_period:int=10000, time_factor:float=1000.0) -> Tensor:
  """
  Create sinusoidal timestep embeddings.

  :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
  :param dim: the dimension of the output.
  :param max_period: controls the minimum frequency of the embeddings.
  :param time_factor: multiplier applied to `t` before embedding.
  :return: an (N, dim) Tensor: [cos | sin] halves, zero-padded by one column when dim is odd.
  """
  scaled = time_factor * t
  half = dim // 2
  # operation order kept as (-log(max_period) * arange) / half
  exponent = -math.log(max_period) * Tensor.arange(0, stop=half, dtype=dtypes.float32) / half
  freqs = Tensor.exp(exponent).to(scaled.device)

  phases = scaled[:, None].float() * freqs[None]
  emb = Tensor.cat(Tensor.cos(phases), Tensor.sin(phases), dim=-1)
  if dim % 2 == 1:
    emb = Tensor.cat(emb, Tensor.zeros_like(emb[:, :1]), dim=-1)
  if Tensor.is_floating_point(scaled):
    emb = emb.cast(scaled.dtype)
  return emb

# https://github.com/black-forest-labs/flux/blob/main/src/flux/model.py
class Flux:
  """
  Transformer model for flow matching on sequences.

  Image and text tokens first pass through `depth` DoubleStreamBlocks (separate weights per stream,
  joint attention), are then concatenated as [txt, img] for `depth_single_blocks` SingleStreamBlocks,
  and only the image part reaches the final projection.
  """

  def __init__(
      self,
      guidance_embed:bool,
      in_channels:int = 64,
      vec_in_dim:int = 768,
      context_in_dim:int = 4096,
      hidden_size:int = 3072,
      mlp_ratio:float = 4.0,
      num_heads:int = 24,
      depth:int = 19,
      depth_single_blocks:int = 38,
      axes_dim:Optional[List[int]] = None,
      theta:int = 10_000,
      qkv_bias:bool = True,
      ):
    axes_dim = axes_dim or [16, 56, 56]
    self.guidance_embed = guidance_embed
    self.in_channels = self.out_channels = in_channels

    if hidden_size % num_heads != 0:
      raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    # the per-axis rotary widths together must cover exactly one attention head
    if sum(axes_dim) != pe_dim:
      raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")

    self.hidden_size, self.num_heads = hidden_size, num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)

    # input projections / conditioning embedders
    self.img_in = nn.Linear(in_channels, hidden_size, bias=True)
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)
    self.vector_in = MLPEmbedder(vec_in_dim, hidden_size)
    self.guidance_in:Callable[[Tensor], Tensor] = MLPEmbedder(in_dim=256, hidden_dim=hidden_size) if guidance_embed else tensor_identity
    self.txt_in = nn.Linear(context_in_dim, hidden_size)

    self.double_blocks = [DoubleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias) for _ in range(depth)]
    self.single_blocks = [SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)]
    self.final_layer = LastLayer(hidden_size, 1, self.out_channels)

  def __call__(self, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, timesteps:Tensor, y:Tensor, guidance:Optional[Tensor] = None) -> Tensor:
    if img.ndim != 3 or txt.ndim != 3:
      raise ValueError("Input img and txt tensors must have 3 dimensions.")

    img = self.img_in(img)

    # conditioning vector: timestep (+ guidance strength for distilled models) + pooled text vector y
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.guidance_embed:
      if guidance is None:
        raise ValueError("Didn't get guidance strength for guidance distilled model.")
      vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)

    txt = self.txt_in(txt)
    # position ids follow the same [txt, img] order as the joint sequence
    pe = self.pe_embedder(Tensor.cat(txt_ids, img_ids, dim=1))

    for block in self.double_blocks:
      img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

    n_txt = txt.shape[1]
    joint = Tensor.cat(txt, img, dim=1)
    for block in self.single_blocks:
      joint = block(joint, vec=vec, pe=pe)

    # drop the text tokens again; output is (N, T, patch_size ** 2 * out_channels)
    return self.final_layer(joint[:, n_txt:, ...], vec)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str, model_path:str):
  """Build the Flux transformer for `name` and load its weights from `model_path` (downloaded when empty)."""
  print("Init model")
  # only the schnell variant is built without a guidance embedder
  model = Flux(guidance_embed=(name != "flux-schnell"))
  checkpoint = safe_load(model_path or fetch(urls[name]))
  # presumably the checkpoint names the RMSNorm parameters "scale" while nn.RMSNorm here uses "weight";
  # note this rewrites every occurrence of "scale" in every key
  renamed = {key.replace("scale", "weight"): value for key, value in checkpoint.items()}
  load_state_dict(model, renamed)
  return model

def load_T5(max_length:int=512):
  """Build the T5 text embedder and load its encoder weights from the two downloaded shards."""
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
  print("Init T5")
  embedder = T5Embedder(max_length, fetch(urls["T5_tokenizer"]))
  shard_paths = [fetch(urls[key]) for key in ("T5_1_of_2", "T5_2_of_2")]
  weights = {}
  for path in shard_paths:
    weights.update(safe_load(path))  # later shards win on duplicate keys
  # strict=False: keys in the checkpoint without a matching encoder attribute are tolerated
  load_state_dict(embedder.encoder, weights, strict=False)
  return embedder

def load_clip():
  """Build the CLIP text embedder and load its transformer weights."""
  print("Init Clip")
  embedder = ClipEmbedder()
  weights = safe_load(fetch(urls["clip"]))
  load_state_dict(embedder.transformer, weights)
  return embedder

def load_ae() -> AutoEncoder:
  """Build the autoencoder (decoder only) with Flux's latent scale/shift and load its weights."""
  print("Init AE")
  autoencoder = AutoEncoder(0.3611, 0.1159)  # (scale_factor, shift_factor)
  weights = safe_load(fetch(urls["ae"]))
  load_state_dict(autoencoder, weights)
  return autoencoder

# https://github.com/black-forest-labs/flux/blob/main/src/flux/sampling.py
def prepare(T5:T5Embedder, clip:ClipEmbedder, img:Tensor, prompt:Union[str, List[str]]) -> Dict[str, Tensor]:
  """
  Build the Flux inputs from a latent image and prompt(s).

  Packs `img` (b, c, h, w) into 2x2 patches -> (b, h/2 * w/2, c * 4), builds (row, col) position ids for the
  patches, embeds the prompt(s) with T5 and CLIP, and broadcasts any batch-of-one tensor to the batch size.
  The batch size is taken from `img`, or from the prompt list when `img` has batch 1.
  """
  bs, _, h, w = img.shape
  if bs == 1 and not isinstance(prompt, str):
    bs = len(prompt)

  def broadcast_batch(t:Tensor) -> Tensor:
    # expand a batch-of-one tensor to bs rows; anything else passes through unchanged
    return t.expand((bs, *t.shape[1:])) if t.shape[0] == 1 and bs > 1 else t

  img = broadcast_batch(img.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

  # position ids per patch: channel 0 stays zero, channel 1 = patch row, channel 2 = patch column
  rows, cols = h // 2, w // 2
  img_ids = Tensor.zeros(rows, cols, 3).contiguous()
  img_ids[..., 1] = img_ids[..., 1] + Tensor.arange(rows)[:, None]
  img_ids[..., 2] = img_ids[..., 2] + Tensor.arange(cols)[None, :]
  img_ids = img_ids.rearrange("h w c -> 1 (h w) c")
  img_ids = img_ids.expand((bs, *img_ids.shape[1:]))

  prompts = [prompt] if isinstance(prompt, str) else prompt
  txt = broadcast_batch(T5(prompts).realize())
  txt_ids = Tensor.zeros(bs, txt.shape[1], 3)  # text tokens all share position id 0
  vec = broadcast_batch(clip(prompts).realize())

  device = img.device
  return {"img": img, "img_ids": img_ids.to(device), "txt": txt.to(device), "txt_ids": txt_ids.to(device), "vec": vec.to(device)}


def get_schedule(num_steps:int, image_seq_len:int, base_shift:float=0.5, max_shift:float=1.15, shift:bool=True) -> List[float]:
  """
  Return the denoising timesteps, from 1.0 down to 0.0.

  :param num_steps: number of denoising steps; the result has num_steps + 1 entries (extra step for zero).
  :param image_seq_len: number of image tokens; used to pick the shift strength.
  :param base_shift: shift exponent mu at 256 image tokens.
  :param max_shift: shift exponent mu at 4096 image tokens (mu is linear in image_seq_len between the two).
  :param shift: when True, warp t -> e^mu / (e^mu + 1/t - 1) to favor high timesteps for higher signal images.
  :raises ValueError: if num_steps < 1.
  """
  if num_steps < 1: raise ValueError(f"num_steps must be >= 1, got {num_steps}")
  # evenly spaced points computed exactly; a float-step arange's element count depends on rounding of
  # (stop - start) / step and can come out one too long
  timesteps = [1.0 - i / num_steps for i in range(num_steps + 1)]

  if shift:
    # estimate mu by linear interpolation between (256, base_shift) and (4096, max_shift)
    mu = base_shift + (max_shift - base_shift) * (image_seq_len - 256) / (4096 - 256)
    e_mu = math.exp(mu)
    # the warp's limit at t -> 0 is 0, so map t == 0 directly instead of dividing by zero
    timesteps = [e_mu / (e_mu + (1 / t - 1)) if t > 0 else 0.0 for t in timesteps]
  return timesteps

@TinyJit
def run(model, *args):
  """One jitted forward pass of `model`; the output is realized before it is returned to the JIT."""
  prediction = model(*args)
  return prediction.realize()

def denoise(model, img:Tensor, img_ids:Tensor, txt:Tensor, txt_ids:Tensor, vec:Tensor, timesteps:List[float], guidance:float=4.0) -> Tensor:
  """
  Integrate the model's predicted velocity with explicit Euler steps over consecutive timestep pairs.

  Each step evaluates the model at t_curr and moves `img` by (t_prev - t_curr) * prediction.
  """
  # models built without a guidance embedder (schnell) ignore this vector
  guidance_vec = Tensor((guidance,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
  step_pairs = list(zip(timesteps, timesteps[1:]))
  for t_curr, t_prev in tqdm(step_pairs, "Denoising"):
    t_vec = Tensor((t_curr,), device=img.device, dtype=img.dtype).expand((img.shape[0],))
    velocity = run(model, img, img_ids, txt, txt_ids, t_vec, vec, guidance_vec)
    img = img + (t_prev - t_curr) * velocity

  return img

def unpack(x:Tensor, height:int, width:int) -> Tensor:
  """Undo the 2x2 patch packing: (b, h*w, c*4) tokens -> (b, c, 2h, 2w) latents, for a height x width image."""
  patch_rows = math.ceil(height / 16)
  patch_cols = math.ceil(width / 16)
  return x.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)", h=patch_rows, w=patch_cols, ph=2, pw=2)

# https://github.com/black-forest-labs/flux/blob/main/src/flux/cli.py
if __name__ == "__main__":
  # CLI entry point: prompt -> T5/CLIP embeddings -> Flux denoising -> AE decode -> PNG.
  # The stages run strictly in this order and each model is deleted before the next one is loaded,
  # presumably to keep peak memory down.
  default_prompt = "bananas and a can of coke"
  parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument("--name",       type=str,   default="flux-schnell", help="Name of the model to load")
  parser.add_argument("--model_path", type=str,   default="",             help="path of the model file")
  parser.add_argument("--width",      type=int,   default=512,            help="width of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--height",     type=int,   default=512,            help="height of the sample in pixels (should be a multiple of 16)")
  parser.add_argument("--seed",       type=int,   default=None,           help="Set a seed for sampling")
  parser.add_argument("--prompt",     type=str,   default=default_prompt, help="Prompt used for sampling")
  parser.add_argument('--out',        type=str,   default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  parser.add_argument("--num_steps",  type=int,   default=None,           help="number of sampling steps (default 4 for schnell, 50 for guidance distilled)") #noqa:E501
  parser.add_argument("--guidance",   type=float, default=3.5,            help="guidance value used for guidance distillation")
  # NOTE(review): --output_dir is parsed but never read below; the image is written to --out
  parser.add_argument("--output_dir", type=str,   default="output",       help="output directory")
  args = parser.parse_args()

  if args.name not in ["flux-schnell", "flux-dev"]:
    raise ValueError(f"Got unknown model name: {args.name}, chose from flux-schnell and flux-dev")

  if args.num_steps is None:
    args.num_steps = 4 if args.name == "flux-schnell" else 50

  # allow for packing and conversion to latent space
  # (round down to a multiple of 16: the latent below is 1/8 of the pixel size and prepare() packs it 2x2)
  height = 16 * (args.height // 16)
  width = 16 * (args.width // 16)

  # without --seed, record tinygrad's current seed so the printed value identifies the run
  if args.seed is None: args.seed = Tensor._seed
  else: Tensor.manual_seed(args.seed)

  print(f"Generating with seed {args.seed}:\n{args.prompt}")
  t0 = time.perf_counter()

  # prepare input noise
  # 16-channel latent at 1/8 of the pixel resolution, in bfloat16; must be drawn after seeding above
  x = Tensor.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype="bfloat16")

  # load text embedders
  T5 = load_T5(max_length=256 if args.name == "flux-schnell" else 512)
  clip = load_clip()

  # embed text to get inputs for model
  inp = prepare(T5, clip, x, prompt=args.prompt)
  # the resolution-dependent timestep shift is disabled for schnell
  timesteps = get_schedule(args.num_steps, inp["img"].shape[1], shift=(args.name != "flux-schnell"))

  # done with text embedders
  del T5, clip

  # load model
  model = load_flow_model(args.name, args.model_path)

  # denoise initial noise
  x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)

  # done with model
  # (the jitted `run` is deleted too; it presumably keeps references to buffers captured while denoising)
  del model, run

  # load autoencoder
  ae = load_ae()

  # decode latents to pixel space
  x = unpack(x.float(), height, width)
  x = ae.decode(x).realize()

  t1 = time.perf_counter()
  print(f"Done in {t1 - t0:.1f}s. Saving {args.out}")

  # bring into PIL format and save
  # decoder output is clamped to [-1, 1] and mapped to [0, 255] as HWC uint8
  x = x.clamp(-1, 1)
  x = x[0].rearrange("c h w -> h w c")
  x = (127.5 * (x + 1.0)).cast("uint8")

  img = Image.fromarray(x.numpy())

  img.save(args.out)

  # validation!
  # only for the exact settings of the stored reference image; the metric is the mean squared pixel difference
  # normalized by the reference image's maximum value
  if args.prompt == default_prompt and args.name=="flux-schnell" and args.seed == 0 and args.width == args.height == 512:
    ref_image = Tensor(np.array(Image.open("examples/flux1_seed0.png")))
    distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
    assert distance < 4e-3, colored(f"validation failed with {distance=}", "red")
    print(colored(f"output validated with {distance=}", "green"))