From 853551bd6ee1314576c4d9a8e539b6e02a7bb205 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 16:34:33 +0900 Subject: [PATCH 01/50] import Flux from https://github.com/black-forest-labs/flux/ License: Apache 2.0 original author: Tim Dockhorn @timudk --- modules/models/flux/math.py | 30 +++ modules/models/flux/model.py | 112 ++++++++++++ modules/models/flux/modules/layers.py | 253 ++++++++++++++++++++++++++ modules/models/flux/util.py | 201 ++++++++++++++++++++ 4 files changed, 596 insertions(+) create mode 100644 modules/models/flux/math.py create mode 100644 modules/models/flux/model.py create mode 100644 modules/models/flux/modules/layers.py create mode 100644 modules/models/flux/util.py diff --git a/modules/models/flux/math.py b/modules/models/flux/math.py new file mode 100644 index 000000000..0156bb6a2 --- /dev/null +++ b/modules/models/flux/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/modules/models/flux/model.py b/modules/models/flux/model.py new file mode 100644 index 000000000..f33ab8323 --- /dev/null +++ b/modules/models/flux/model.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + +from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. 
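+
+    Example (illustrative only: tensor names are placeholders and the parameter
+    values mirror the "flux-dev" config in util.py; shapes are as consumed by
+    forward() below):
+
+        params = FluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096,
+                            hidden_size=3072, mlp_ratio=4.0, num_heads=24,
+                            depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56],
+                            theta=10_000, qkv_bias=True, guidance_embed=True)
+        model = Flux(params)
+        # img: (B, L_img, 64), img_ids: (B, L_img, 3), txt: (B, L_txt, 4096),
+        # txt_ids: (B, L_txt, 3), timesteps: (B,), y: (B, 768), guidance: (B,)
+        pred = model(img, img_ids, txt, txt_ids, timesteps, y, guidance)  # (B, L_img, 64)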
+ """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/modules/models/flux/modules/layers.py b/modules/models/flux/modules/layers.py new file mode 100644 index 000000000..091ddf624 --- /dev/null +++ b/modules/models/flux/modules/layers.py @@ -0,0 +1,253 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from flux.math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. 
+ These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = 
Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
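+
+    Data flow as implemented in forward() below: the modulated pre-norm output
+    goes through one fused projection (linear1) that is split into qkv and the
+    MLP hidden state; attention runs on qkv while the MLP stream is activated,
+    the two are concatenated, projected back to hidden_size by linear2, gated
+    by the modulation output and added to the residual.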
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/modules/models/flux/util.py b/modules/models/flux/util.py new file mode 100644 index 000000000..77fc76c09 --- /dev/null +++ b/modules/models/flux/util.py @@ -0,0 +1,201 @@ +import os +from dataclasses import dataclass + +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from imwatermark import WatermarkEncoder +from safetensors.torch import load_file as load_sft + +from flux.model import Flux, FluxParams +from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams +from flux.modules.conditioner import HFEmbedder + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + 
repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + 
squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was chosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watermark = WatermarkEmbedder(WATERMARK_BITS) From d38732efae195d93388828fbf4b77034b3150ec5 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 16:40:28 +0900 Subject: [PATCH 02/50] add flux model wrapper --- modules/models/flux/__init__.py | 5 + modules/models/flux/flux.py | 338 ++++++++++++++++++++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 modules/models/flux/__init__.py create mode 100644 modules/models/flux/flux.py diff --git a/modules/models/flux/__init__.py b/modules/models/flux/__init__.py new file mode 100644 index 000000000..1cc52a00b --- /dev/null +++ b/modules/models/flux/__init__.py @@ -0,0 +1,5 @@ +from .flux import FLUX1Inferencer + +__all__ = [ + "FLUX1Inferencer", +] diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py new file mode 100644 index 000000000..d17febc68 --- /dev/null +++ b/modules/models/flux/flux.py @@ -0,0 +1,338 @@ +import contextlib + +import os +import safetensors +import torch +import math + +import k_diffusion +from transformers import CLIPTokenizer + +from modules import shared, devices, modelloader, sd_hijack_clip + +from modules.models.sd3.sd3_impls import SDVAE +from modules.models.sd3.sd3_cond import CLIPL_CONFIG, T5_CONFIG, CLIPL_URL, T5_URL, SafetensorsMapping, Sd3T5 +from modules.models.sd3.other_impls import SDClipModel, T5XXLModel, SDTokenizer, T5XXLTokenizer +from PIL import Image + +from .model import Flux + + +class FluxTokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text:str): + out = {} + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + return out + + +class Flux1ClipL(sd_hijack_clip.TextConditionalModel): + def __init__(self, clip_l): + super().__init__() + + self.clip_l = clip_l + + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + empty = self.tokenizer('')["input_ids"] + self.id_start = empty[0] + self.id_end = empty[1] + self.id_pad = empty[1] + + self.return_pooled = True + + def tokenize(self, texts): + return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + def encode_with_transformers(self, tokens): + tokens_g = tokens.clone() + + for batch_pos in range(tokens_g.shape[0]): + index = tokens_g[batch_pos].cpu().tolist().index(self.id_end) + tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0 + + l_out, l_pooled = self.clip_l(tokens) + l_out = 
torch.cat([l_out], dim=-1) + l_out = torch.nn.functional.pad(l_out, (0, 4096 - l_out.shape[-1])) + + vector_out = torch.cat([l_pooled], dim=-1) + + l_out.pooled = vector_out + + return l_out + + def encode_embedding_init_text(self, init_text, nvpt): + return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX + + + +class FluxCond(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.tokenizer = FluxTokenizer() + + with torch.no_grad(): + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + + if shared.opts.sd3_enable_t5: + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + else: + self.t5xxl = None + + self.model_l = Flux1ClipL(self.clip_l) + self.model_t5 = Sd3T5(self.t5xxl) + + def forward(self, prompts: list[str]): + with devices.without_autocast(): + l_out, vector_out = self.model_l(prompts) + t5_out = self.model_t5(prompts, token_count=l_out.shape[1]) + lt_out = torch.cat([l_out, t5_out], dim=-2) + + return { + 'crossattn': lt_out, + 'vector': vector_out, + } + + def before_load_weights(self, state_dict): + clip_path = os.path.join(shared.models_path, "CLIP") + + if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict: + clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") + with safetensors.safe_open(clip_l_file, framework="pt") as file: + self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") + with safetensors.safe_open(t5_file, framework="pt") as file: + self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + def encode_embedding_init_text(self, init_text, nvpt): + return self.model_l.encode_embedding_init_text(init_text, nvpt) + + def tokenize(self, texts): + return self.model_l.tokenize(texts) + + def medvram_modules(self): + return [self.clip_l, self.t5xxl] + + def get_token_count(self, text): + _, token_count = self.model_l.process_texts([text]) + + return token_count + + def get_target_prompt_token_count(self, token_count): + return self.model_l.get_target_prompt_token_count(token_count) + +def flux_time_shift(mu: float, sigma: float, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + +class ModelSamplingFlux(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + self.set_parameters(shift=sampling_settings.get("shift", 1.15)) + + def set_parameters(self, shift=1.15, timesteps=10000): + self.shift = shift + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps)) + self.register_buffer('sigmas', ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma + + def sigma(self, timestep): + return flux_time_shift(self.shift, 1.0, timestep) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 1.0 + if percent >= 1.0: + return 0.0 + return 1.0 - percent + + def calculate_denoised(self, sigma, model_output, 
model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + +class BaseModel(torch.nn.Module): + """Wrapper around the core FLUX model""" + def __init__(self, shift=1.0, device=None, dtype=torch.float16, state_dict=None, prefix=""): + super().__init__() + + params = dict( + image_model="flux", + in_channels=16, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10000, + qkv_bias=True, + guidance_embed=True, + ) + + self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params) + self.model_sampling = ModelSamplingFlux() + self.depth = 19 + + def apply_model(self, x, sigma, c_crossattn=None, y=None): + dtype = self.get_dtype() + timestep = self.model_sampling.timestep(sigma).float() + guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=dtype) + model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) + + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) + + def get_dtype(self): + return self.diffusion_model.dtype + + +class FLUX1LatentFormat: + """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift""" + def __init__(self): + self.scale_factor = 0.3611 + self.shift_factor = 0.1159 + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scale_factor + + def process_out(self, latent): + return (latent / self.scale_factor) + self.shift_factor + + def decode_latent_to_preview(self, x0): + """Quick RGB approximate preview of sd3 latents""" + factors = torch.tensor([ + [-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850], + [ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549], + [ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123], + [ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306], + [-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545], + [-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088], + [ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897], + [-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680], + ], device="cpu") + latent_image = x0[0].permute(1, 2, 0).cpu() @ factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +class FLUX1Denoiser(k_diffusion.external.DiscreteSchedule): + def __init__(self, inner_model, sigmas): + super().__init__(sigmas, quantize=shared.opts.enable_quantization) + self.inner_model = inner_model + + def forward(self, input, sigma, **kwargs): + return self.inner_model.apply_model(input, sigma, **kwargs) + + +class FLUX1Inferencer(torch.nn.Module): + def __init__(self, state_dict, use_ema=False): + super().__init__() + + # detect model_prefix + diffusion_model_prefix = "model.diffusion_model." + if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + diffusion_model_prefix = "model.diffusion_model." 
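+        # prefixed keys ("model.diffusion_model.double_blocks...") come from
+        # full webui-style checkpoints; transformer-only FLUX files (e.g. the
+        # reference flux1-dev.safetensors) store the same keys without a
+        # prefix, which the elif below handles.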
+ elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + diffusion_model_prefix = "" + + with torch.no_grad(): + self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype) + self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) + self.first_stage_model.dtype = self.model.diffusion_model.dtype + + self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) + + self.text_encoders = FluxCond() + self.cond_stage_key = 'txt' + + self.parameterization = "eps" + self.model.conditioning_key = "crossattn" + + self.latent_format = FLUX1LatentFormat() + self.latent_channels = 16 + + @property + def cond_stage_model(self): + return self.text_encoders + + def before_load_weights(self, state_dict): + self.cond_stage_model.before_load_weights(state_dict) + + def ema_scope(self): + return contextlib.nullcontext() + + def get_learned_conditioning(self, batch: list[str]): + return self.cond_stage_model(batch) + + def apply_model(self, x, t, cond): + return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector']) + + def decode_first_stage(self, latent): + latent = self.latent_format.process_out(latent) + return self.first_stage_model.decode(latent) + + def encode_first_stage(self, image): + latent = self.first_stage_model.encode(image) + return self.latent_format.process_in(latent) + + def get_first_stage_encoding(self, x): + return x + + def create_denoiser(self): + return FLUX1Denoiser(self, self.model.model_sampling.sigmas) + + def medvram_fields(self): + return [ + (self, 'first_stage_model'), + (self, 'text_encoders'), + (self, 'model'), + ] + + def add_noise_to_latent(self, x, noise, amount): + return x * (1 - amount) + noise * amount + + def fix_dimensions(self, width, height): + return width // 16 * 16, height // 16 * 16 + + def diffusers_weight_mapping(self): + for i in range(self.model.depth): + yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj" + + yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj" From 2d1db1a2d0878cd56521b71def2de1c5e7ca279c Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 16:47:17 +0900 Subject: [PATCH 03/50] fix for flux --- configs/flux1-inference.yaml | 4 ++ modules/models/flux/math.py | 2 +- modules/models/flux/model.py | 67 ++++++++++++----- modules/models/flux/modules/layers.py | 100 ++++++++++++++------------ modules/models/flux/util.py | 6 +- 5 files changed, 113 insertions(+), 66 deletions(-) create mode 100644 configs/flux1-inference.yaml diff --git a/configs/flux1-inference.yaml b/configs/flux1-inference.yaml new file mode 100644 index 000000000..f9bbe9073 --- /dev/null +++ 
b/configs/flux1-inference.yaml @@ -0,0 +1,4 @@ +model: + target: modules.models.flux.FLUX1Inferencer + params: + state_dict: null diff --git a/modules/models/flux/math.py b/modules/models/flux/math.py index 0156bb6a2..4ad99818e 100644 --- a/modules/models/flux/math.py +++ b/modules/models/flux/math.py @@ -22,7 +22,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: return out.float() -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] diff --git a/modules/models/flux/model.py b/modules/models/flux/model.py index f33ab8323..c87397827 100644 --- a/modules/models/flux/model.py +++ b/modules/models/flux/model.py @@ -1,9 +1,13 @@ +# original code from https://github.com/black-forest-labs/flux +# from dataclasses import dataclass import torch + +from einops import rearrange, repeat from torch import Tensor, nn -from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding) @@ -18,7 +22,7 @@ class FluxParams: num_heads: int depth: int depth_single_blocks: int - axes_dim: list[int] + axes_dim: list theta: int qkv_bias: bool guidance_embed: bool @@ -29,11 +33,13 @@ class Flux(nn.Module): Transformer model for flow matching on sequences. """ - def __init__(self, params: FluxParams): + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs): super().__init__() + self.dtype = dtype + params = FluxParams(**kwargs) self.params = params - self.in_channels = params.in_channels + self.in_channels = params.in_channels * 2 * 2 self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: raise ValueError( @@ -45,13 +51,13 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device) self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity() ) - self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) self.double_blocks = nn.ModuleList( [ @@ -60,6 +66,7 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + dtype=dtype, device=device, ) for _ in range(params.depth) ] @@ -67,14 +74,15 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, 
dtype=dtype, device=device) for _ in range(params.depth_single_blocks) ] ) - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device) - def forward( + def forward_orig( self, img: Tensor, img_ids: Tensor, @@ -82,18 +90,18 @@ class Flux(nn.Module): txt_ids: Tensor, timesteps: Tensor, y: Tensor, - guidance: Tensor | None = None, + guidance: Tensor = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) + vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) if self.params.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.vector_in(y) txt = self.txt_in(txt) @@ -108,5 +116,32 @@ class Flux(nn.Module): img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + if self.final_layer: + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + + def forward(self, x, timestep, context, y, guidance, **kwargs): + # from comfy/ldm/common_dit.py + def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): + if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): + padding_mode = "reflect" + pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] + pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] + return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) + + bs, c, h, w = x.shape + patch_size = 2 + x = pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance) + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] diff --git a/modules/models/flux/modules/layers.py b/modules/models/flux/modules/layers.py index 091ddf624..aa830849e 100644 --- a/modules/models/flux/modules/layers.py +++ b/modules/models/flux/modules/layers.py @@ -5,11 +5,11 @@ import torch from einops import rearrange from torch import Tensor, nn -from flux.math import attention, rope +from ..math import attention, rope class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: list): super().__init__() self.dim = dim self.theta = theta @@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 """ t = 
time_factor * t half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - t.device - ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -50,20 +48,20 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int): + def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None): super().__init__() - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) class RMSNorm(torch.nn.Module): - def __init__(self, dim: int): + def __init__(self, dim: int, dtype=None, device=None): super().__init__() - self.scale = nn.Parameter(torch.ones(dim)) + self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): x_dtype = x.dtype @@ -73,26 +71,26 @@ class RMSNorm(torch.nn.Module): class QKNorm(torch.nn.Module): - def __init__(self, dim: int): + def __init__(self, dim: int, dtype=None, device=None): super().__init__() - self.query_norm = RMSNorm(dim) - self.key_norm = RMSNorm(dim) + self.query_norm = RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = RMSNorm(dim, dtype=dtype, device=device) - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) k = self.key_norm(k) return q.to(v), k.to(v) class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.norm = QKNorm(head_dim) - self.proj = nn.Linear(dim, dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.norm = QKNorm(head_dim, dtype=dtype, device=device) + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) @@ -111,13 +109,13 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool): + def __init__(self, dim: int, double: bool, dtype=None, device=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) - def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + def forward(self, vec: Tensor) -> tuple: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) return ( @@ -127,35 +125,35 @@ class Modulation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = 
False, dtype=None, device=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True) - self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device) - self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.img_mlp = nn.Sequential( - nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), nn.GELU(approximate="tanh"), - nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - self.txt_mod = Modulation(hidden_size, double=True) - self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device) - self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.txt_mlp = nn.Sequential( - nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), nn.GELU(approximate="tanh"), - nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -188,6 +186,11 @@ class DoubleStreamBlock(nn.Module): # calculate the txt bloks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) + + return img, txt @@ -202,7 +205,9 @@ class SingleStreamBlock(nn.Module): hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, - qk_scale: float | None = None, + qk_scale: float = None, + dtype=None, + device=None, ): super().__init__() self.hidden_dim = hidden_size @@ -212,17 +217,17 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + self.linear2 = 
nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) - self.norm = QKNorm(head_dim) + self.norm = QKNorm(head_dim, dtype=dtype, device=device) self.hidden_size = hidden_size - self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False) + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device) def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) @@ -236,15 +241,18 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - return x + mod.gate * output + x += mod.gate * output + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None): super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor) -> Tensor: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) diff --git a/modules/models/flux/util.py b/modules/models/flux/util.py index 77fc76c09..9303eb7cf 100644 --- a/modules/models/flux/util.py +++ b/modules/models/flux/util.py @@ -7,9 +7,9 @@ from huggingface_hub import hf_hub_download from imwatermark import WatermarkEncoder from safetensors.torch import load_file as load_sft -from flux.model import Flux, FluxParams -from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams -from flux.modules.conditioner import HFEmbedder +from .model import Flux, FluxParams +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder @dataclass From 821e76a415129dd272d26b5c16658b1c05bd3201 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 14 Sep 2024 00:08:44 +0900 Subject: [PATCH 04/50] use empty_like for speed --- modules/sd_disable_initialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 3750e85e9..0fc1596b7 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -188,7 +188,7 @@ class LoadStateDictOnMeta(ReplaceHelper): if param.is_meta: dtype = sd_param.dtype if sd_param is not None else param.dtype - module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + module._parameters[name] = torch.nn.parameter.Parameter(torch.empty_like(param, device=device, 
dtype=dtype), requires_grad=param.requires_grad) for name in module._buffers: key = prefix + name From 39328bd7db32f852ea12fb371457f2ae50da69e3 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 5 Sep 2024 09:17:26 +0900 Subject: [PATCH 05/50] fix misc * check supported dtypes * detect non_blocking * update autocast() to use non_blocking, target_device and current_dtype --- modules/devices.py | 82 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 63 insertions(+), 19 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ee679141a..556e72d2e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -128,6 +128,26 @@ dtype_unet: torch.dtype = torch.float16 dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False +supported_vae_dtypes = [torch.float16, torch.float32] + + +# prepare available dtypes +if torch.version.cuda: + if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: + supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes + if has_xpu(): + supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes + + +def supports_non_blocking(): + if has_mps() or has_xpu(): + return False + + if npu_specific.has_npu: + return False + + return True + def cond_cast_unet(input): if force_fp16: @@ -149,14 +169,23 @@ patch_module_list = [ ] -def manual_cast_forward(target_dtype): +def manual_cast_forward(target_dtype, target_device=None): + params = dict() + if supports_non_blocking(): + params['non_blocking'] = True + def forward_wrapper(self, *args, **kwargs): - if any( - isinstance(arg, torch.Tensor) and arg.dtype != target_dtype - for arg in args - ): - args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + if target_device is not None: + params['device'] = target_device + params['dtype'] = target_dtype + + args = list(args) + for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype): + args[j] = args[j].to(**params) + args = tuple(args) + + for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype): + kwargs[key] = kwargs[key].to(**params) org_dtype = target_dtype for param in self.parameters(): @@ -165,37 +194,41 @@ def manual_cast_forward(target_dtype): break if org_dtype != target_dtype: - self.to(target_dtype) + self.to(**params) result = self.org_forward(*args, **kwargs) + if org_dtype != target_dtype: - self.to(org_dtype) + params['dtype'] = org_dtype + self.to(**params) if target_dtype != dtype_inference: + params['dtype'] = dtype_inference if isinstance(result, tuple): result = tuple( - i.to(dtype_inference) + i.to(**params) if isinstance(i, torch.Tensor) else i for i in result ) elif isinstance(result, torch.Tensor): - result = result.to(dtype_inference) + result = result.to(**params) return result return forward_wrapper @contextlib.contextmanager -def manual_cast(target_dtype): +def manual_cast(target_dtype, target_device=None): applied = False + for module_type in patch_module_list: if hasattr(module_type, "org_forward"): continue applied = True org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention: - module_type.forward = manual_cast_forward(torch.float32) + module_type.forward = manual_cast_forward(torch.float32, target_device) else: - module_type.forward = manual_cast_forward(target_dtype) + module_type.forward = 
manual_cast_forward(target_dtype, target_device) module_type.org_forward = org_forward try: yield None @@ -207,26 +240,37 @@ def manual_cast(target_dtype): delattr(module_type, "org_forward") -def autocast(disable=False): +def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None): if disable: return contextlib.nullcontext() + if target_dtype is None: + target_dtype = dtype + if target_device is None: + target_device = device + if force_fp16: # No casting during inference if force_fp16 is enabled. # All tensor dtype conversion happens before inference. return contextlib.nullcontext() - if fp8 and device==cpu: + if fp8 and target_device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and dtype_inference == torch.float32: - return manual_cast(dtype) + return manual_cast(target_dtype, target_device) - if dtype == torch.float32 or dtype_inference == torch.float32: + if target_dtype != dtype_inference: + return manual_cast(target_dtype, target_device) + + if current_dtype is not None and current_dtype != target_dtype: + return manual_cast(target_dtype, target_device) + + if target_dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype) + return manual_cast(target_dtype, target_device) return torch.autocast("cuda") From c972951cf69d46bee9c09bd64e0215378f4fa852 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 5 Sep 2024 09:34:08 +0900 Subject: [PATCH 06/50] check Unet/VAE and load as is - check float8 unet dtype to save memory - check vae/ text_encoders dtype and use as intended --- modules/sd_models.py | 159 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 150 insertions(+), 9 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 1c7d370e9..43e0b9208 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,90 @@ def set_model_fields(model): model.latent_channels = 4 +def get_state_dict_dtype(state_dict): + # detect dtypes of state_dict + state_dict_dtype = {} + + known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.") + + for k in state_dict.keys(): + found = [prefix for prefix in known_prefixes if k.startswith(prefix)] + if len(found) > 0: + prefix = found[0] + dtype = state_dict[k].dtype + dtypes = state_dict_dtype.get(prefix, {}) + if dtype in dtypes: + dtypes[dtype] += 1 + else: + dtypes[dtype] = 1 + state_dict_dtype[prefix] = dtypes + + for prefix in state_dict_dtype: + dtypes = state_dict_dtype[prefix] + # sort by count + state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True)) + + print("Detected dtypes:", state_dict_dtype) + return state_dict_dtype + + +def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490): + if state_dict is not None: + state_dict_dtype = get_state_dict_dtype(state_dict) + + aliases = { + "FP8": "F8", + "FP16": "F16", + "FP32": "F32", + } + + loadables = { + "F8": (torch.float8_e4m3fn,), + "F16": (torch.float16,), + "F32": (torch.float32,), + "BF16": (torch.bfloat16,), + } + + if dtype is None: + # get the first dtype + if prefix in state_dict_dtype: + return list(state_dict_dtype[prefix])[0] + return None + + + if dtype in aliases: + dtype = aliases[dtype] + loadable = loadables[dtype] + + if prefix in state_dict_dtype: + dtypes = [d for d in 
state_dict_dtype[prefix].keys() if d in loadable] + if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count: + # mostly dtype weights. + return dtypes[0] + + return None + + +def get_vae_dtype(state_dict=None, state_dict_dtype=None): + if state_dict is not None: + state_dict_dtype = get_state_dict_dtype(state_dict) + + if state_dict_dtype is None: + raise ValueError("fail to get vae dtype") + + + vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype] + + if len(vae_prefixes) > 0: + vae_prefix = vae_prefixes[0] + for dtype in state_dict_dtype[vae_prefix]: + if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16): + # vae items: 248 for SD1, SDXL 245 for flux + return dtype + + return None + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") @@ -441,6 +525,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(model, "before_load_weights"): model.before_load_weights(state_dict) + # get all dtypes of state_dict + state_dict_dtype = get_state_dict_dtype(state_dict) + model.load_state_dict(state_dict, strict=False) timer.record("apply weights to model") @@ -466,7 +553,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.to(memory_format=torch.channels_last) timer.record("apply channels_last") - if shared.cmd_opts.no_half: + # check dtype of vae + dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype) + found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype) + unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) + + if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half: + # unet type is not detected or unet has float dtypes model.float() model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 @@ -476,8 +569,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer vae = model.first_stage_model depth_model = getattr(model, 'depth_model', None) + if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes: + # preserve bfloat16 if it supported + model.first_stage_model = None # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: + elif shared.cmd_opts.no_half_vae: model.first_stage_model = None # with --upcast-sampling, don't convert the depth model weights to float16 if shared.cmd_opts.upcast_sampling and depth_model: @@ -485,15 +581,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer alphas_cumprod = model.alphas_cumprod model.alphas_cumprod = None - model.half() + + + if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16): + model.half() + elif found_unet_dtype in (torch.float8_e4m3fn,): + pass + else: + print("Fail to get a vaild UNet dtype. 
ignore...") + model.alphas_cumprod = alphas_cumprod model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model - devices.dtype_unet = torch.float16 - timer.record("apply half()") + if found_unet_dtype in (torch.float16, torch.float32): + devices.dtype_unet = torch.float16 + timer.record("apply half()") + else: + print(f"load Unet {found_unet_dtype} as is ...") + devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16 + timer.record("load UNet") apply_alpha_schedule_override(model) @@ -503,10 +612,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(module, 'fp16_bias'): del module.fp16_bias - if check_fp8(model): + if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model): devices.fp8 = True + + # do not convert vae, text_encoders.clip_l, clip_g, t5xxl first_stage = model.first_stage_model model.first_stage_model = None + vae = getattr(model, 'vae', None) + if vae is not None: + model.vae = None + text_encoders = getattr(model, 'text_encoders', None) + if text_encoders is not None: + model.text_encoders = None for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: @@ -514,6 +631,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if module.bias is not None: module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) + if text_encoders is not None: + model.text_encoders = text_encoders + if vae is not None: + model.vae = vae model.first_stage_model = first_stage timer.record("apply fp8") else: @@ -521,8 +642,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) - timer.record("apply dtype to VAE") + # check supported vae dtype + dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype) + if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes: + devices.dtype_vae = torch.bfloat16 + print(f"VAE dtype {dtype_vae} detected. 
load as is.") + else: + # use default devices.dtype_vae + model.first_stage_model.to(devices.dtype_vae) + print(f"Use VAE dtype {devices.dtype_vae}") + timer.record("apply dtype to VAE") # clean up cache if limit is reached while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: @@ -818,6 +947,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ print(f"Creating model from config: {checkpoint_config}") + # get all dtypes of state_dict + state_dict_dtype = get_state_dict_dtype(state_dict) + + # check loadable unet dtype before loading + loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype) + + # check dtype of vae + dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype) + if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes: + devices.dtype_vae = torch.bfloat16 + print(f"VAE dtype {dtype_vae} detected.") + sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): @@ -843,7 +984,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ weight_dtype_conversion = { 'first_stage_model': None, 'alphas_cumprod': None, - '': torch.float16, + '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None, } with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): From fcd609f4b4e04263361420a6ecc1d5252ea6fb28 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 14 Sep 2024 06:57:41 +0900 Subject: [PATCH 07/50] simplified get_loadable_dtype --- modules/sd_models.py | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 43e0b9208..75ce55e2a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -434,40 +434,13 @@ def get_state_dict_dtype(state_dict): return state_dict_dtype -def get_loadable_dtype(prefix="model.diffusion_model.", dtype=None, state_dict=None, state_dict_dtype=None, count=490): +def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None): if state_dict is not None: state_dict_dtype = get_state_dict_dtype(state_dict) - aliases = { - "FP8": "F8", - "FP16": "F16", - "FP32": "F32", - } - - loadables = { - "F8": (torch.float8_e4m3fn,), - "F16": (torch.float16,), - "F32": (torch.float32,), - "BF16": (torch.bfloat16,), - } - - if dtype is None: - # get the first dtype - if prefix in state_dict_dtype: - return list(state_dict_dtype[prefix])[0] - return None - - - if dtype in aliases: - dtype = aliases[dtype] - loadable = loadables[dtype] - + # get the first dtype if prefix in state_dict_dtype: - dtypes = [d for d in state_dict_dtype[prefix].keys() if d in loadable] - if len(dtypes) > 0 and state_dict_dtype[prefix][dtypes[0]] >= count: - # mostly dtype weights. 
- return dtypes[0] - + return list(state_dict_dtype[prefix])[0] return None From 537d9dd71c92fd97def33a55150c70e1d7d80e27 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 7 Sep 2024 01:06:05 +0900 Subject: [PATCH 08/50] misc fixes to support float8 dtype_unet * devices.dtype_unet, dtype_vae could be considered as storage dtypes (current_dtype) * use devices.dtype_inference as computational dtype (taget_dtype) * misc fixes to support float8 unet storage --- modules/processing.py | 2 +- modules/sd_hijack_unet.py | 14 +++++++------- modules/sd_samplers_common.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92c3582cc..b1107e757 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: sd_models.apply_alpha_schedule_override(p.sd_model, p) - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if p.scripts is not None: diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index b4f03b138..842030be8 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]] else: - cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] + cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs) if devices.unet_needs_upcast: return result.float() else: @@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): torch.nn.GELU.__init__(self, *args, **kwargs) def forward(self, x): if devices.unet_needs_upcast: - return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference) else: return torch.nn.GELU.forward(self, x) @@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if 
timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) - CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 @@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): if devices.unet_needs_upcast and timesteps.dtype == torch.int64: dtype = torch.float32 else: - dtype = devices.dtype_unet + dtype = devices.dtype_inference return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index c060cccb2..28b8bd820 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): else: if model is None: model = shared.sd_model - with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 + with torch.no_grad(), devices.manual_cast(devices.dtype_vae): # fixes an issue with unstable VAEs that are flaky even in fp32 x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) return x_sample @@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None): x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = 255. 
* np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) From 2060886450aea0b01035da73b1660515f50cf88e Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 7 Sep 2024 12:31:00 +0900 Subject: [PATCH 09/50] add shared.opts.lora_without_backup_weight option to reduce ram usage --- extensions-builtin/Lora/networks.py | 12 +++++++++++- modules/shared_options.py | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 67f9abe2a..e58e1fb56 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -377,6 +377,8 @@ def store_weights_backup(weight): if weight is None: return None + if shared.opts.lora_without_backup_weight: + return True return weight.to(devices.cpu, copy=True) @@ -395,6 +397,9 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li if weights_backup is None and bias_backup is None: return + if weights_backup is True or weights_backup == (True, True): # fake backup + return + if weights_backup is not None: if isinstance(self, torch.nn.MultiheadAttention): restore_weights_backup(self, 'in_proj_weight', weights_backup[0]) @@ -539,7 +544,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 - self.network_current_names = wanted_names + + if weights_backup is True or weights_backup == (True, True): # fake backup + self.network_weights_backup = None + self.network_bias_backup = None + else: + self.network_current_names = wanted_names def network_forward(org_module, input, original_forward): diff --git a/modules/shared_options.py b/modules/shared_options.py index 78089cbec..6b6faf332 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -242,6 +242,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd" "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. 
Use more system ram."), + "lora_without_backup_weight": OptionInfo(False, "LoRA without backup weights").info("LoRA without backup weights to save RAM."), })) options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { From 2f72fd89ff532b78949257a66c37eac81bf33f1b Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 7 Sep 2024 12:23:03 +0900 Subject: [PATCH 10/50] support copy option to reduce ram usage --- modules/devices.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 556e72d2e..866b6ab16 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,5 +1,6 @@ import sys import contextlib +from copy import deepcopy from functools import lru_cache import torch @@ -169,7 +170,7 @@ patch_module_list = [ ] -def manual_cast_forward(target_dtype, target_device=None): +def manual_cast_forward(target_dtype, target_device=None, copy=False): params = dict() if supports_non_blocking(): params['non_blocking'] = True @@ -193,13 +194,22 @@ def manual_cast_forward(target_dtype, target_device=None): org_dtype = param.dtype break - if org_dtype != target_dtype: - self.to(**params) - result = self.org_forward(*args, **kwargs) + if copy: + copied = deepcopy(self) + if org_dtype != target_dtype: + copied.to(**params) - if org_dtype != target_dtype: - params['dtype'] = org_dtype - self.to(**params) + result = copied.org_forward(*args, **kwargs) + del copied + else: + if org_dtype != target_dtype: + self.to(**params) + + result = self.org_forward(*args, **kwargs) + + if org_dtype != target_dtype: + params['dtype'] = org_dtype + self.to(**params) if target_dtype != dtype_inference: params['dtype'] = dtype_inference @@ -220,15 +230,17 @@ def manual_cast_forward(target_dtype, target_device=None): def manual_cast(target_dtype, target_device=None): applied = False + copy = shared.opts.lora_without_backup_weight + for module_type in patch_module_list: if hasattr(module_type, "org_forward"): continue applied = True org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention: - module_type.forward = manual_cast_forward(torch.float32, target_device) + module_type.forward = manual_cast_forward(torch.float32, target_device, copy) else: - module_type.forward = manual_cast_forward(target_dtype, target_device) + module_type.forward = manual_cast_forward(target_dtype, target_device, copy) module_type.org_forward = org_forward try: yield None From 24f2c1b9e447a5a890596bb05c84113cb959aa5a Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 12 Sep 2024 00:23:00 +0900 Subject: [PATCH 11/50] fix to support dtype_inference != dtype case --- modules/processing.py | 4 ++-- modules/sd_models.py | 2 +- modules/sd_models_xl.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index b1107e757..3b23ab7af 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -484,7 +484,7 @@ class StableDiffusionProcessing: cache = caches[0] - with devices.autocast(): + with devices.autocast(target_dtype=devices.dtype_inference): cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) cache[0] = cached_params @@ -1439,7 +1439,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.hr_extra_network_data) - with devices.autocast(): + with devices.autocast(target_dtype=devices.dtype_inference): 
self.calculate_hr_conds() sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) diff --git a/modules/sd_models.py b/modules/sd_models.py index 75ce55e2a..41649970b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -984,7 +984,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ timer.record("scripts callbacks") - with devices.autocast(), torch.no_grad(): + with devices.autocast(target_dtype=devices.dtype_inference), torch.no_grad(): sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1242a5936..5abb75d1f 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -18,7 +18,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: is_negative_prompt = getattr(batch, 'is_negative_prompt', False) aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score - devices_args = dict(device=devices.device, dtype=devices.dtype) + devices_args = dict(device=devices.device, dtype=devices.dtype_inference) sdxl_conds = { "txt": batch, From 477ff355177124ec9a2a538aae0b535f9f72fede Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 5 Sep 2024 09:19:04 +0900 Subject: [PATCH 12/50] preserve detected dtype_inference --- modules/shared_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_init.py b/modules/shared_init.py index a6ad0433d..2d58c6374 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -29,7 +29,7 @@ def initialize(): devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 - devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype + devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype_inference if cmd_opts.precision == "half": msg = "--no-half and --no-half-vae conflict with --precision half" From d6a609a5392b6e3945efda6c71ea6a1b3a2ce2ac Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 5 Sep 2024 20:15:37 +0900 Subject: [PATCH 13/50] add diffusers weight mapping for flux lora * add QkvLinear class for Flux lora --- modules/models/flux/flux.py | 39 ++++++++++++++++++++------- modules/models/flux/modules/layers.py | 8 ++++-- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index d17febc68..fc1e91e9d 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -196,7 +196,8 @@ class BaseModel(torch.nn.Module): self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params) self.model_sampling = ModelSamplingFlux() - self.depth = 19 + self.depth = params['depth'] + self.depth_single_block = params['depth_single_blocks'] def apply_model(self, x, sigma, c_crossattn=None, y=None): dtype = self.get_dtype() @@ -326,13 +327,33 @@ class FLUX1Inferencer(torch.nn.Module): return width // 16 * 16, height // 16 * 16 def diffusers_weight_mapping(self): + # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py + # please see also https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py for i in range(self.model.depth): - yield 
f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj" - yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj" - yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj" - yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_v_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_add_out", f"diffusion_model_double_blocks_{i}_txt_attn_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_double_blocks_{i}_img_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_double_blocks_{i}_img_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_double_blocks_{i}_img_attn_qkv_v_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_double_blocks_{i}_img_attn_proj" + + yield f"transformer.transformer_blocks.{i}.ff.net.0.proj", f"diffusion_model_double_blocks_{i}_img_mlp_0" + yield f"transformer.transformer_blocks.{i}.ff.net.2", f"diffusion_model_double_blocks_{i}_img_mlp_2" + yield f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", f"diffusion_model_double_blocks_{i}_txt_mlp_0" + yield f"transformer.transformer_blocks.{i}.ff_context.net.2", f"diffusion_model_double_blocks_{i}_txt_mlp_2" + yield f"transformer.transformer_blocks.{i}.norm1.linear", f"diffusion_model_double_blocks_{i}_img_mod_lin" + yield f"transformer.transformer_blocks.{i}.norm1_context.linear", f"diffusion_model_double_blocks_{i}_txt_mod_lin" + + for i in range(self.model.depth_single_block): + yield f"transformer.single_transformer_blocks.{i}.attn.to_q", f"diffusion_model_single_blocks_{i}_linear1_q_proj" + yield f"transformer.single_transformer_blocks.{i}.attn.to_k", f"diffusion_model_single_blocks_{i}_linear1_k_proj" + yield f"transformer.single_transformer_blocks.{i}.attn.to_v", f"diffusion_model_single_blocks_{i}_linear1_v_proj" + yield f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj" + + yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2" + yield f"transformer.single_transformer_blocks.{i}.morm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin" diff --git a/modules/models/flux/modules/layers.py b/modules/models/flux/modules/layers.py index aa830849e..2202f5dcb 100644 --- a/modules/models/flux/modules/layers.py +++ b/modules/models/flux/modules/layers.py @@ -82,13 +82,17 @@ class 
QKNorm(torch.nn.Module): return q.to(v), k.to(v) +class QkvLinear(torch.nn.Linear): + pass + + class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device) self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) @@ -217,7 +221,7 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = QkvLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) # proj and mlp_out self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) From 51c285265f5b67561ef093977407688a04ba5e34 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 6 Sep 2024 14:07:31 +0900 Subject: [PATCH 14/50] fix for Lora flux --- extensions-builtin/Lora/network_lora.py | 3 ++- extensions-builtin/Lora/networks.py | 28 ++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index a7a088949..2bc6af5d2 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -2,6 +2,7 @@ import torch import lyco_helpers import modules.models.sd3.mmdit +import modules.models.flux.modules.layers import network from modules import devices @@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule): if weight is None and none_ok: return None - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear] + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear ] is_conv = type(self.sd_module) in [torch.nn.Conv2d] if is_linear: diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index e58e1fb56..46c9ac1ff 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -37,7 +37,7 @@ module_types = [ re_digits = re.compile(r"\d+") -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$") re_compiled = {} suffix_conversion = { @@ -460,7 +460,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn for net in loaded_networks: module = net.modules.get(network_layer_name, None) - if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear): + if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)): try: with torch.no_grad(): if getattr(self, 'fp16_weight', None) is None: @@ -520,7 +520,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn continue - if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v: + module_mlp = 
net.modules.get(network_layer_name + "_mlp_proj", None) + + if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None: try: with torch.no_grad(): # Send "real" orig_weight into MHA's lora module @@ -531,6 +533,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn del qw, kw, vw updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) self.weight += updown_qkv + del updown_qkv + + except RuntimeError as e: + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + + continue + + if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp: + try: + with torch.no_grad(): + qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0) + updown_q, _ = module_q.calc_updown(qw) + updown_k, _ = module_k.calc_updown(kw) + updown_v, _ = module_v.calc_updown(vw) + updown_mlp, _ = module_v.calc_updown(mlp) + del qw, kw, vw, mlp + updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp]) + self.weight += updown_qkv_mlp + del updown_qkv_mlp except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") From 7e2d51965f8b9260b3b3cff21fd397fdba9b3db1 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 5 Sep 2024 08:57:42 +0900 Subject: [PATCH 15/50] fix for t5xxl --- modules/models/flux/flux.py | 11 +++++++---- modules/shared_options.py | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index fc1e91e9d..eb17a31c7 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -80,7 +80,7 @@ class FluxCond(torch.nn.Module): with torch.no_grad(): self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) - if shared.opts.sd3_enable_t5: + if shared.opts.flux_enable_t5: self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) else: self.t5xxl = None @@ -107,7 +107,7 @@ class FluxCond(torch.nn.Module): with safetensors.safe_open(clip_l_file, framework="pt") as file: self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) - if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict: t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") with safetensors.safe_open(t5_file, framework="pt") as file: self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) @@ -194,7 +194,7 @@ class BaseModel(torch.nn.Module): guidance_embed=True, ) - self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params) + self.diffusion_model = Flux(device=device, dtype=dtype, **params) self.model_sampling = ModelSamplingFlux() self.depth = params['depth'] self.depth_single_block = params['depth_single_blocks'] @@ -301,7 +301,10 @@ class FLUX1Inferencer(torch.nn.Module): def decode_first_stage(self, latent): latent = self.latent_format.process_out(latent) - return self.first_stage_model.decode(latent) + x = 
self.first_stage_model.decode(latent) + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x def encode_first_stage(self, image): latent = self.first_stage_model.encode(image) diff --git a/modules/shared_options.py b/modules/shared_options.py index 6b6faf332..0cb10acfd 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -195,6 +195,9 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), { "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), })) +options_templates.update(options_section(('flux', "Stable Diffusion FLUX", "sd"), { + "flux_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), +})) options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" From 9c0fd83b5e822bf0c6330e69f959aa2d15ad8f17 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 8 Sep 2024 20:21:48 +0900 Subject: [PATCH 16/50] vae fix for flux --- modules/models/flux/flux.py | 3 ++- modules/sd_models.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index eb17a31c7..14fa4e255 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -270,7 +270,8 @@ class FLUX1Inferencer(torch.nn.Module): with torch.no_grad(): self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype) self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) - self.first_stage_model.dtype = self.model.diffusion_model.dtype + self.first_stage_model.dtype = devices.dtype_vae + self.vae = self.first_stage_model # real vae self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index 41649970b..b4702151a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -956,6 +956,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ else: weight_dtype_conversion = { 'first_stage_model': None, + 'text_encoders': None, + 'vae': None, 'alphas_cumprod': None, '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None, } From 219a0e242900388c9f68628294a1881821c9e920 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 16:46:34 +0900 Subject: [PATCH 17/50] support Flux1 --- modules/models/sd3/other_impls.py | 27 +++++++++++++++++---------- modules/sd_models.py | 12 ++++++++++-- modules/sd_models_config.py | 4 ++++ modules/sd_models_types.py | 3 +++ modules/sd_vae_taesd.py | 8 ++++++-- 5 files changed, 40 insertions(+), 14 deletions(-) diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index 78c1dc687..695d356cb 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear): def forward(self, x): return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) +class AutocastLayerNorm(nn.LayerNorm): + def forward(self, x): + return torch.nn.functional.layer_norm( + x, self.normalized_shape, self.weight.to(x.dtype), 
self.bias.to(x.dtype) if self.bias is not None else None, self.eps) + def attention(q, k, v, heads, mask=None): """Convenience wrapper around a basic attention operation""" @@ -41,9 +46,9 @@ class Mlp(nn.Module): out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) self.act = act_layer - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) def forward(self, x): x = self.fc1(x) @@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device): super().__init__() self.heads = heads - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) def forward(self, x, mask=None): q = self.q_proj(x) @@ -82,9 +87,11 @@ ACTIVATIONS = { class CLIPLayer(torch.nn.Module): def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + #self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device) self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device) + #self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device) @@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module): super().__init__() self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l')) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device) def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): x = self.embeddings(input_tokens) @@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module): self.num_layers = config_dict["num_hidden_layers"] self.text_model = CLIPTextModel_(config_dict, dtype, device) embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection = 
AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) self.text_projection.weight.copy_(torch.eye(embed_dim)) self.dtype = dtype diff --git a/modules/sd_models.py b/modules/sd_models.py index b4702151a..7e48c2328 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -33,6 +33,7 @@ class ModelType(enum.Enum): SDXL = 3 SSD = 4 SD3 = 5 + FLUX1 = 6 def replace_key(d, key, new_key, value): @@ -369,7 +370,7 @@ def check_fp8(model): enable_fp8 = False elif shared.opts.fp8_storage == "Enable": enable_fp8 = True - elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL": enable_fp8 = True else: enable_fp8 = False @@ -382,10 +383,14 @@ def set_model_type(model, state_dict): model.is_sdxl = False model.is_ssd = False model.is_sd3 = False + model.is_flux1 = False if "model.diffusion_model.x_embedder.proj.weight" in state_dict: model.is_sd3 = True model.model_type = ModelType.SD3 + elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + model.is_flux1 = True + model.model_type = ModelType.FLUX1 elif hasattr(model, 'conditioner'): model.is_sdxl = True @@ -777,6 +782,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' +clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight' +clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight' +t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight' class SdModelData: @@ -909,7 +917,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ if not checkpoint_config: checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) + clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight ] if x in state_dict) timer.record("find config") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 3c1e4a151..4251062c8 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") +config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml") def is_using_v_parameterization_for_sd2(state_dict): @@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename): if "model.diffusion_model.x_embedder.proj.weight" in sd: return config_sd3 + if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: + return config_flux1 + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: if diffusion_model_input.shape[1] == 9: return config_sdxl_inpainting diff --git a/modules/sd_models_types.py 
b/modules/sd_models_types.py index 2fce2777b..867f8b6e2 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion): is_sd3: bool """True if the model's architecture is SD 3""" + is_flux1: bool + """True if the model's architecture is FLUX 1""" + latent_channels: int """number of layer in latent image representation; will be 16 in SD3 and 4 in other version""" diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index d06253d2a..76771e95e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(decoder_path) else 4 + latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4 self.decoder = decoder(latent_channels) self.decoder.load_state_dict( @@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(encoder_path) else 4 + latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4 self.encoder = encoder(latent_channels) self.encoder.load_state_dict( @@ -97,6 +97,8 @@ def download_model(model_path, model_url): def decoder_model(): if shared.sd_model.is_sd3: model_name = "taesd3_decoder.pth" + elif shared.sd_model.is_flux1: + model_name = "taef1_decoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_decoder.pth" else: @@ -122,6 +124,8 @@ def decoder_model(): def encoder_model(): if shared.sd_model.is_sd3: model_name = "taesd3_encoder.pth" + elif shared.sd_model.is_flux1: + model_name = "taef1_encoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_encoder.pth" else: From 9e57c722b284eb7bc8e6fca225a89da6425eb924 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 23:22:43 +0900 Subject: [PATCH 18/50] fix to support float8_* --- modules/sd_disable_initialization.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 0fc1596b7..aa57c3e06 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper): self.state_dict = state_dict self.device = device self.weight_dtype_conversion = weight_dtype_conversion or {} - self.default_dtype = self.weight_dtype_conversion.get('') + self.default_dtype = self.weight_dtype_conversion.get('', None) def get_weight_dtype(self, key): key_first_term, _ = key.split('.', 1) @@ -183,7 +183,11 @@ class LoadStateDictOnMeta(ReplaceHelper): key = prefix + name sd_param = sd.pop(key, None) if sd_param is not None: - state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) + dtype = self.get_weight_dtype(key) + if dtype is None: + state_dict[key] = sd_param + else: + state_dict[key] = sd_param.to(dtype=dtype) used_param_keys.append(key) if param.is_meta: From 789bfc7db4ca86fb8717e3871ad96b5f3fcd888a Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 18:07:57 +0900 Subject: [PATCH 19/50] add cheap approximation for flux --- modules/sd_vae_approx.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index c5dda7431..48ffd672d 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -44,6 +44,8 @@ def model(): model_name = "vaeapprox-sd3.pt" elif shared.sd_model.is_sdxl: model_name = 
"vaeapprox-sdxl.pt" + elif shared.sd_model.is_flux1: + model_name = "vaeapprox-sd3.pt" else: model_name = "model.pt" @@ -81,6 +83,18 @@ def cheap_approximation(sample): [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259], ] + elif shared.sd_model.is_flux1: + coeffs = [ + # from comfy + [-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850], + [ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549], + [ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123], + [ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306], + [-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545], + [-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088], + [ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897], + [-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680], + ] elif shared.sd_model.is_sdxl: coeffs = [ [ 0.3448, 0.4168, 0.4395], From 44a8480f0c9979dc143d9cb0f5c92aa9960474ac Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 10 Sep 2024 19:05:41 +0900 Subject: [PATCH 20/50] minor update * use dtype_inference --- modules/models/flux/flux.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index 14fa4e255..46fd568a0 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -78,10 +78,10 @@ class FluxCond(torch.nn.Module): self.tokenizer = FluxTokenizer() with torch.no_grad(): - self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) if shared.opts.flux_enable_t5: - self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference) else: self.t5xxl = None @@ -202,8 +202,8 @@ class BaseModel(torch.nn.Module): def apply_model(self, x, sigma, c_crossattn=None, y=None): dtype = self.get_dtype() timestep = self.model_sampling.timestep(sigma).float() - guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=dtype) - model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).float() + guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32) + model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype) return self.model_sampling.calculate_denoised(sigma, model_output, x) def forward(self, *args, **kwargs): @@ -268,7 +268,7 @@ class FLUX1Inferencer(torch.nn.Module): diffusion_model_prefix = "" with torch.no_grad(): - self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype) + self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference) self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) self.first_stage_model.dtype = devices.dtype_vae self.vae = self.first_stage_model # real vae From 9617f15fd9c35eaafc35f343c42bbe10b7e6c6d1 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 13 Sep 2024 20:00:49 +0900 Subject: [PATCH 21/50] pytest with --precision full --no-half --- .github/workflows/run_tests.yaml | 1 + 1 file changed, 1 insertion(+) 
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 0610f4f54..eccf15e46 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -51,6 +51,7 @@ jobs: --test-server --do-not-download-clip --no-half + --precision full --disable-opt-split-attention --use-cpu all --api-server-stop From 3cdc26af3083e8fbe9763005f4cb382428bf6230 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 15 Sep 2024 23:29:45 +0900 Subject: [PATCH 22/50] fix lora without backup --- extensions-builtin/Lora/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 46c9ac1ff..dfdf3c7e6 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -397,7 +397,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li if weights_backup is None and bias_backup is None: return - if weights_backup is True or weights_backup == (True, True): # fake backup + if shared.opts.lora_without_backup_weight: return if weights_backup is not None: @@ -567,7 +567,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 - if weights_backup is True or weights_backup == (True, True): # fake backup + if shared.opts.lora_without_backup_weight: self.network_weights_backup = None self.network_bias_backup = None else: From 3b18b6f482b741af523a50a3340cd26b51f8e3d7 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 16 Sep 2024 06:33:46 +0900 Subject: [PATCH 23/50] revert to use without_autocast() --- modules/sd_samplers_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 28b8bd820..b312c41d8 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): else: if model is None: model = shared.sd_model - with torch.no_grad(), devices.manual_cast(devices.dtype_vae): # fixes an issue with unstable VAEs that are flaky even in fp32 + with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) return x_sample From 2ffdf01e05c8e25926ea29ed1ea3c95ea9c26dc2 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 17 Sep 2024 10:05:13 +0900 Subject: [PATCH 24/50] fix position_ids --- modules/sd_models.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 7e48c2328..3e0b577bb 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -469,6 +469,25 @@ def get_vae_dtype(state_dict=None, state_dict_dtype=None): return None +def fix_position_ids(state_dict, force=False): + # for SD1.5 or some SDXL with position_ids + for prefix in ("cond_stage_models.", "conditioner.embedders.0."): + position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids" + if position_id_key in state_dict: + original = state_dict[position_id_key] + if original.dtype == torch.int64: + return + + if force: + # regenerate + fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device) + else: + fixed = state_dict[position_id_key].to(torch.int64) + print(f"Warning: Fixed position_ids dtype from {original.dtype} to 
{fixed.dtype}") + + state_dict[position_id_key] = fixed + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") @@ -490,6 +509,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer else: model.ztsnr = False + fix_position_ids(state_dict) + + if model.is_sdxl: sd_models_xl.extend_sdxl(model) From 1e73a287075e6d3ef9059c8748b58972b6f5a367 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 17 Sep 2024 10:07:58 +0900 Subject: [PATCH 25/50] fix for float8_e5m2 freeze model --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3e0b577bb..6dffdc036 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -585,7 +585,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16): model.half() - elif found_unet_dtype in (torch.float8_e4m3fn,): + elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): pass else: print("Fail to get a vaild UNet dtype. ignore...") @@ -612,7 +612,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(module, 'fp16_bias'): del module.fp16_bias - if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model): + if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model): devices.fp8 = True # do not convert vae, text_encoders.clip_l, clip_g, t5xxl From 1318f6118e14982e4288ebf5820459da1d00c86c Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 17 Sep 2024 11:05:48 +0900 Subject: [PATCH 26/50] fix load_vae() to check size mismatch --- modules/sd_vae.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 43687e48d..6ae038333 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + loaded = False if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) - _load_vae_dict(model, checkpoints_loaded[vae_file]) + loaded = _load_vae_dict(model, checkpoints_loaded[vae_file]) else: assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) - _load_vae_dict(model, vae_dict_1) + loaded = _load_vae_dict(model, vae_dict_1) - if cache_enabled: + if loaded and cache_enabled: # cache newly loaded vae checkpoints_loaded[vae_file] = vae_dict_1.copy() # clean up cache if limit is reached - if cache_enabled: + if loaded and cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though vae_opt = get_filename(vae_file) - if vae_opt not in vae_dict: + if loaded and vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file elif loaded_vae_file: restore_base_vae(model) + loaded = True - loaded_vae_file = vae_file + if loaded: + loaded_vae_file = vae_file 
model.base_vae = base_vae model.loaded_vae_file = loaded_vae_file + return loaded # don't call this from outside def _load_vae_dict(model, vae_dict_1): + conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight") + # check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3] + if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape: + print("Failed to load VAE. Size mismatched!") + return False + model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) + return True def clear_loaded_vae(): @@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file, vae_source) + loaded = load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) @@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): script_callbacks.model_loaded_callback(sd_model) - print("VAE weights loaded.") + if loaded: + print("VAE weights loaded.") return sd_model From eee7294200dd95fd686bbd7cab01f4955a8f7336 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 17 Sep 2024 17:03:11 +0900 Subject: [PATCH 27/50] add fix_unet_prefix() to support unet only checkpoints --- modules/sd_models.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 6dffdc036..da9b7c6f4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -282,6 +282,30 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +def fix_unet_prefix(state_dict): + known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.") + + for k in state_dict.keys(): + found = [prefix for prefix in known_prefixes if k.startswith(prefix)] + if len(found) > 0: + return state_dict + + # no known prefix found. 
+ # in this case, this is a unet only state_dict + known_keys = ( + "input_blocks.0.0.weight", # SD1.5, SD2, SDXL + "joint_blocks.0.context_block.adaLN_modulation.1.weight", # SD3 + "double_blocks.0.img_attn.proj.weight", # FLUX + ) + + if any(key in state_dict for key in known_keys): + state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()} + print("Fixed state_dict keys...") + return state_dict + + return state_dict + + def read_metadata_from_safetensors(filename): import json @@ -343,6 +367,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") res = read_state_dict(checkpoint_info.filename) + res = fix_unet_prefix(res) timer.record("load weights from disk") return res From 6675d1f090318a730dbb1636f8ddfe5626924f11 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 18 Sep 2024 02:18:00 +0900 Subject: [PATCH 28/50] use assign=True for some cases --- modules/sd_disable_initialization.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index aa57c3e06..d60fb5591 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -176,6 +176,11 @@ class LoadStateDictOnMeta(ReplaceHelper): def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): used_param_keys = [] + if type(module) in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm,): + # HACK add assign=True to local_metadata for some cases + args[0]['assign_to_params_buffers'] = True + + for name, param in module._parameters.items(): if param is None: continue From 1f779226f0dc01d97fbc3ea556137fd19bd7e412 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 19 Sep 2024 09:29:46 +0900 Subject: [PATCH 29/50] check lora_unet prefix to support Black Forest Labs's lora --- extensions-builtin/Lora/networks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index dfdf3c7e6..5f00d36e7 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -183,8 +183,12 @@ def load_network(name, network_on_disk): for key_network, weight in sd.items(): if diffusers_weight_map: - key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2) - network_part = network_name + '.' + network_weight + if key_network.startswith("lora_unet"): + key_network_without_network_parts, _, network_part = key_network.partition(".") + key_network_without_network_parts = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + else: + key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2) + network_part = network_name + '.' 
+ network_weight else: key_network_without_network_parts, _, network_part = key_network.partition(".") From 380e9a84c3e8ca37d5ad6f39d4aef75545fd80ca Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 19 Sep 2024 10:24:59 +0900 Subject: [PATCH 30/50] call lowvram.send_everything_to_cpu() for interrupted case --- modules/processing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 3b23ab7af..ed140983d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1147,6 +1147,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess(p, res) + + if lowvram.is_enabled(shared.sd_model): + # for interrupted case + lowvram.send_everything_to_cpu() + return res From 71b430f703df99c0809904fa9facc9366d8c3c9e Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 19 Sep 2024 13:34:16 +0900 Subject: [PATCH 31/50] call torch_gc() to fix VRAM usage spike when call decode_first_stage() --- modules/lowvram.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/lowvram.py b/modules/lowvram.py index 6728c337b..9914a06c6 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,6 +53,7 @@ def setup_for_low_vram(sd_model, use_medvram): if module_in_gpu is not None: module_in_gpu.to(cpu) + devices.torch_gc() module.to(devices.device) module_in_gpu = module From f569f6eb1e40849dbbc93938887a5623c0b6d078 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 19 Sep 2024 23:56:07 +0900 Subject: [PATCH 32/50] use text_encoders.t5xxl.transformer.shared.weight tokens weights * some T5XXL do not have encoder.embed_tokens.weight. use shared.weight embed_tokens instead. * use float8 text encoder t5xxl_fp8_e4m3fn.safetensors --- modules/models/sd3/other_impls.py | 21 ++++++++++++++------- modules/models/sd3/sd3_cond.py | 12 ++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index 695d356cb..5492d5e11 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -449,7 +449,7 @@ class T5Attention(torch.nn.Module): else: mask = None - out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None) + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None) return self.o(out), past_bias @@ -486,15 +486,17 @@ class T5Stack(torch.nn.Module): self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)]) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True): intermediate = None - x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes + #x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes + # some T5XXL do not embed_token. 
use shared token instead like comfy past_bias = None for i, layer in enumerate(self.block): x, past_bias = layer(x, past_bias) if i == intermediate_output: intermediate = x.clone() x = self.final_layer_norm(x) + x = torch.nan_to_num(x) if intermediate is not None and final_layer_norm_intermediate: intermediate = self.final_layer_norm(intermediate) return x, intermediate @@ -505,13 +507,18 @@ class T5(torch.nn.Module): super().__init__() self.num_layers = config_dict["num_layers"] self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) + self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"]) self.dtype = dtype def get_input_embeddings(self): - return self.encoder.embed_tokens + #return self.encoder.embed_tokens + return self.shared def set_input_embeddings(self, embeddings): - self.encoder.embed_tokens = embeddings + #self.encoder.embed_tokens = embeddings + self.shared = embeddings - def forward(self, *args, **kwargs): - return self.encoder(*args, **kwargs) + def forward(self, input_ids, *args, **kwargs): + x = self.shared(input_ids).float() + x = torch.nan_to_num(x) + return self.encoder(x, *args, **kwargs) diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py index 6a43f569b..7058950a5 100644 --- a/modules/models/sd3/sd3_cond.py +++ b/modules/models/sd3/sd3_cond.py @@ -43,7 +43,7 @@ CLIPG_CONFIG = { "textual_inversion_key": "clip_g", } -T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors" +T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16_e4m3fn.safetensors" T5_CONFIG = { "d_ff": 10240, "d_model": 4096, @@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module): self.tokenizer = SD3Tokenizer() with torch.no_grad(): - self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype) - self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) if shared.opts.sd3_enable_t5: - self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference) else: self.t5xxl = None @@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module): with safetensors.safe_open(clip_l_file, framework="pt") as file: self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) - if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: - t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict: + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors") with safetensors.safe_open(t5_file, framework="pt") as file: self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) From 28eca469594fabe9ddbae99c6557d40a04df1398 Mon Sep 17 00:00:00 2001 From: 
Won-Kyu Park Date: Fri, 20 Sep 2024 00:00:34 +0900 Subject: [PATCH 33/50] fix flux to use float8 t5xxl --- modules/models/flux/flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index 46fd568a0..42e9ea788 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -108,7 +108,7 @@ class FluxCond(torch.nn.Module): self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict: - t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors") with safetensors.safe_open(t5_file, framework="pt") as file: self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) From 4bea93bc06c8c65f7ab93769ac4c3889ff3b9c72 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 20 Sep 2024 08:27:58 +0900 Subject: [PATCH 34/50] fixed typo in the flux lora map --- modules/models/flux/flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index 42e9ea788..109c72904 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -360,4 +360,4 @@ class FLUX1Inferencer(torch.nn.Module): yield f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj" yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2" - yield f"transformer.single_transformer_blocks.{i}.morm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin" + yield f"transformer.single_transformer_blocks.{i}.norm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin" From 30d0f950b7bb5afc7eb0dbfc10f0529714c4eaa5 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 20 Sep 2024 12:30:38 +0900 Subject: [PATCH 35/50] fixed ai-toolkit flux lora support * fixed some mistake * some ai-toolkit's lora do not have proj_mlp --- extensions-builtin/Lora/networks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 5f00d36e7..12e1c24e1 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -526,7 +526,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None) - if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None: + if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]: try: with torch.no_grad(): # Send "real" orig_weight into MHA's lora module @@ -545,14 +545,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn continue - if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp: + if any(isinstance(self, linear) for linear in 
(modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v: try: with torch.no_grad(): qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0) updown_q, _ = module_q.calc_updown(qw) updown_k, _ = module_k.calc_updown(kw) updown_v, _ = module_v.calc_updown(vw) - updown_mlp, _ = module_v.calc_updown(mlp) + if module_mlp is not None: + updown_mlp, _ = module_mlp.calc_updown(mlp) + else: + updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device) del qw, kw, vw, mlp updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp]) self.weight += updown_qkv_mlp From 11c9bc719c3deb7b02bcc85232eab3e3d7e2f4a8 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 20 Sep 2024 13:07:27 +0900 Subject: [PATCH 36/50] make Sd3T5 shared.opts.sd3_enable_t5 independent --- modules/models/sd3/sd3_cond.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py index 7058950a5..fc0232325 100644 --- a/modules/models/sd3/sd3_cond.py +++ b/modules/models/sd3/sd3_cond.py @@ -140,7 +140,7 @@ class Sd3T5(torch.nn.Module): return tokens, multipliers def forward(self, texts, *, token_count): - if not self.t5xxl or not shared.opts.sd3_enable_t5: + if not self.t5xxl: return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype) tokens_batch = [] From 8c9c139c654f6a41f1b5757bcf38930a0186371d Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 21 Sep 2024 01:02:49 +0900 Subject: [PATCH 37/50] support Flux schnell and cleanup --- modules/models/flux/flux.py | 65 +++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index 109c72904..6326e97e6 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -133,14 +133,10 @@ def flux_time_shift(mu: float, sigma: float, t): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) class ModelSamplingFlux(torch.nn.Module): - def __init__(self, model_config=None): + def __init__(self, shift=1.15): super().__init__() - if model_config is not None: - sampling_settings = model_config.sampling_settings - else: - sampling_settings = {} - self.set_parameters(shift=sampling_settings.get("shift", 1.15)) + self.set_parameters(shift=shift) def set_parameters(self, shift=1.15, timesteps=10000): self.shift = shift @@ -175,29 +171,13 @@ class ModelSamplingFlux(torch.nn.Module): class BaseModel(torch.nn.Module): """Wrapper around the core FLUX model""" - def __init__(self, shift=1.0, device=None, dtype=torch.float16, state_dict=None, prefix=""): + def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs): super().__init__() - params = dict( - image_model="flux", - in_channels=16, - vec_in_dim=768, - context_in_dim=4096, - hidden_size=3072, - mlp_ratio=4.0, - num_heads=24, - depth=19, - depth_single_blocks=38, - axes_dim=[16, 56, 56], - theta=10000, - qkv_bias=True, - guidance_embed=True, - ) - - self.diffusion_model = Flux(device=device, dtype=dtype, **params) - self.model_sampling = ModelSamplingFlux() - self.depth = params['depth'] - self.depth_single_block = params['depth_single_blocks'] + self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs) + self.model_sampling = ModelSamplingFlux(shift=shift) + self.depth = kwargs['depth'] + self.depth_single_block = kwargs['depth_single_blocks'] def apply_model(self, x, sigma, 
c_crossattn=None, y=None): dtype = self.get_dtype() @@ -215,9 +195,9 @@ class BaseModel(torch.nn.Module): class FLUX1LatentFormat: """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift""" - def __init__(self): - self.scale_factor = 0.3611 - self.shift_factor = 0.1159 + def __init__(self, scale_factor=0.3611, shift_factor=0.1159): + self.scale_factor = scale_factor + self.shift_factor = shift_factor def process_in(self, latent): return (latent - self.shift_factor) * self.scale_factor @@ -260,6 +240,22 @@ class FLUX1Inferencer(torch.nn.Module): def __init__(self, state_dict, use_ema=False): super().__init__() + params = dict( + image_model="flux", + in_channels=16, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10000, + qkv_bias=True, + guidance_embed=True, + ) + # detect model_prefix diffusion_model_prefix = "model.diffusion_model." if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: @@ -267,8 +263,15 @@ class FLUX1Inferencer(torch.nn.Module): elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: diffusion_model_prefix = "" + shift=1.15 + # check guidance_in to detect Flux schnell + if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict: + print("Flux schnell detected") + params.update(dict(guidance_embed=False,)) + shift=1.0 + with torch.no_grad(): - self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference) + self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params) self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) self.first_stage_model.dtype = devices.dtype_vae self.vae = self.first_stage_model # real vae From 2e6533519ba379914f6e32abd361f7274b99686c Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 22 Sep 2024 00:42:45 +0900 Subject: [PATCH 38/50] fix some nn.Embedding to set dtype=float32 for some float8 freeze model --- modules/models/sd3/other_impls.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index 5492d5e11..4524fa019 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -377,7 +377,7 @@ class T5Attention(torch.nn.Module): if relative_attention_bias: self.relative_attention_num_buckets = 32 self.relative_attention_max_distance = 128 - self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32) @staticmethod def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): @@ -482,7 +482,7 @@ class T5Block(torch.nn.Module): class T5Stack(torch.nn.Module): def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): super().__init__() - self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + #self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32) self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in 
range(num_layers)]) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) @@ -507,7 +507,7 @@ class T5(torch.nn.Module): super().__init__() self.num_layers = config_dict["num_layers"] self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) - self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"]) + self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32) self.dtype = dtype def get_input_embeddings(self): From 03516f48f0911b33bc2c62b0aca6940e2641eaf9 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 22 Sep 2024 21:45:57 +0900 Subject: [PATCH 39/50] use isinstance() --- modules/sd_disable_initialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index d60fb5591..47f98416e 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -176,7 +176,7 @@ class LoadStateDictOnMeta(ReplaceHelper): def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): used_param_keys = [] - if type(module) in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm,): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm)): # HACK add assign=True to local_metadata for some cases args[0]['assign_to_params_buffers'] = True From ba499f92ac1d6b6afb5242f86d1dca3816e60575 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 25 Sep 2024 02:20:28 +0900 Subject: [PATCH 40/50] use shared.opts.lora_without_backup_weight option in the devices.autocast() * add nn.Embedding in the devices.autocast() * do not cast forward args for some cases * add copy option in the devices.autocast() --- modules/devices.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 866b6ab16..ec6ec5634 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -167,6 +167,7 @@ patch_module_list = [ torch.nn.MultiheadAttention, torch.nn.GroupNorm, torch.nn.LayerNorm, + torch.nn.Embedding, ] @@ -175,6 +176,10 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False): if supports_non_blocking(): params['non_blocking'] = True + supported_cast_dtypes = [torch.float16, torch.float32] + if torch.cuda.is_bf16_supported(): + supported_cast_dtypes += [torch.bfloat16] + def forward_wrapper(self, *args, **kwargs): if target_device is not None: params['device'] = target_device @@ -182,11 +187,13 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False): args = list(args) for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype): - args[j] = args[j].to(**params) + if args[j].dtype in supported_cast_dtypes: + args[j] = args[j].to(**params) args = tuple(args) for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype): - kwargs[key] = kwargs[key].to(**params) + if kwargs[key].dtype in supported_cast_dtypes: + kwargs[key] = kwargs[key].to(**params) org_dtype = target_dtype for param in self.parameters(): @@ -227,10 +234,9 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False): @contextlib.contextmanager -def manual_cast(target_dtype, target_device=None): +def manual_cast(target_dtype, 
target_device=None, copy=None): applied = False - copy = shared.opts.lora_without_backup_weight for module_type in patch_module_list: if hasattr(module_type, "org_forward"): @@ -252,10 +258,12 @@ def manual_cast(target_dtype, target_device=None): delattr(module_type, "org_forward") -def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None): +def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None): if disable: return contextlib.nullcontext() + copy = copy if copy is not None else shared.opts.lora_without_backup_weight + if target_dtype is None: target_dtype = dtype if target_device is None: @@ -270,13 +278,13 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and dtype_inference == torch.float32: - return manual_cast(target_dtype, target_device) + return manual_cast(target_dtype, target_device, copy=copy) - if target_dtype != dtype_inference: - return manual_cast(target_dtype, target_device) + if target_dtype != dtype_inference or copy: + return manual_cast(target_dtype, target_device, copy=copy) if current_dtype is not None and current_dtype != target_dtype: - return manual_cast(target_dtype, target_device) + return manual_cast(target_dtype, target_device, copy=copy) if target_dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() From 5f3314ec43b099a06637a207a1d5e4252d59ad58 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 25 Sep 2024 21:58:28 +0900 Subject: [PATCH 41/50] do not use copy option for nn.Embedding --- modules/devices.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ec6ec5634..5b763ec85 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -201,7 +201,7 @@ def manual_cast_forward(target_dtype, target_device=None, copy=False): org_dtype = param.dtype break - if copy: + if copy and not isinstance(self, torch.nn.Embedding): copied = deepcopy(self) if org_dtype != target_dtype: copied.to(**params) @@ -266,8 +266,6 @@ def autocast(disable=False, current_dtype=None, target_dtype=None, target_device if target_dtype is None: target_dtype = dtype - if target_device is None: - target_device = device if force_fp16: # No casting during inference if force_fp16 is enabled. 
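
Note on the two devices.py changes above: manual_cast() works by swapping in a casting forward for every class in patch_module_list and restoring the original when the context manager exits. A minimal sketch of that patch-and-restore shape (illustration only, with simplified names -- patched_forward and casting_forward are not names from the webui code):

    import contextlib
    import torch

    @contextlib.contextmanager
    def patched_forward(module_type, casting_forward):
        # keep the original around so the wrapper can call self.org_forward(...)
        module_type.org_forward = module_type.forward
        module_type.forward = casting_forward
        try:
            yield
        finally:
            # always restore the original forward, even if inference raised
            module_type.forward = module_type.org_forward
            del module_type.org_forward

    # usage sketch: with patched_forward(torch.nn.Linear, wrapper): run the forward passes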
From 4ad5f22c7b1c7f9fd7dd218d48cfd1f80d420fd6 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Wed, 25 Sep 2024 22:12:46 +0900
Subject: [PATCH 42/50] do not use assign=True for nn.LayerNorm

---
 modules/sd_disable_initialization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 47f98416e..0261db08f 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -176,7 +176,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
         def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []
 
-            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm, torch.nn.LayerNorm)):
+            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm,)):
                 # HACK add assign=True to local_metadata for some cases
                 args[0]['assign_to_params_buffers'] = True
 

From 98cb284eb1942493d21afee56e1d2cfa7ea53a18 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Sat, 28 Sep 2024 21:03:35 +0900
Subject: [PATCH 43/50] flux: clean up some dead code

---
 modules/models/flux/flux.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py
index 6326e97e6..a7370af25 100644
--- a/modules/models/flux/flux.py
+++ b/modules/models/flux/flux.py
@@ -50,12 +50,6 @@ class Flux1ClipL(sd_hijack_clip.TextConditionalModel):
         return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
 
     def encode_with_transformers(self, tokens):
-        tokens_g = tokens.clone()
-
-        for batch_pos in range(tokens_g.shape[0]):
-            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
-            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
-
         l_out, l_pooled = self.clip_l(tokens)
         l_out = torch.cat([l_out], dim=-1)
         l_out = torch.nn.functional.pad(l_out, (0, 4096 - l_out.shape[-1]))

From 1d3dae1471c0050fe3995c6fa35b342d619f6496 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Sat, 28 Sep 2024 23:19:08 +0900
Subject: [PATCH 44/50] add a task manager, based on
 https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py

* reworked into a class (TaskManager)
* this way, gc.collect() will work as intended.
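
A rough usage sketch (illustration only: heavy_job is a made-up stand-in for the
function that wrap_gradio_gpu_call() forwards, and it assumes
manager.task.main_loop() is already running on the main thread):

    from modules import manager

    def heavy_job(id_task, steps):
        # stand-in for the real t2i/i2i processing
        return f"{id_task}: finished after {steps} steps"

    # called from a gradio worker thread; the job itself is executed on the
    # main thread by main_loop(), and the caller blocks (polling) until the
    # result, or the raised exception, is handed back
    result = manager.task.run_and_wait_result(heavy_job, "task(demo)", 20)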
--- modules/call_queue.py | 4 +- modules/launch_utils.py | 6 +++ modules/manager.py | 83 +++++++++++++++++++++++++++++++++++++++++ webui.py | 21 ++++++++--- 4 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 modules/manager.py diff --git a/modules/call_queue.py b/modules/call_queue.py index 555c35312..b20badcaf 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -3,7 +3,7 @@ from functools import wraps import html import time -from modules import shared, progress, errors, devices, fifo_lock, profiling +from modules import shared, progress, errors, devices, fifo_lock, profiling, manager queue_lock = fifo_lock.FIFOLock() @@ -34,7 +34,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): progress.start_task(id_task) try: - res = func(*args, **kwargs) + res = manager.task.run_and_wait_result(func, *args, **kwargs) progress.record_results(id_task, res) finally: progress.finish_task(id_task) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 20c7dc127..5c868747e 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -463,11 +463,17 @@ def configure_for_tests(): def start(): print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}") import webui + + from modules import manager + if '--nowebui' in sys.argv: webui.api_only() else: webui.webui() + manager.task.main_loop() + return + def dump_sysinfo(): from modules import sysinfo diff --git a/modules/manager.py b/modules/manager.py new file mode 100644 index 000000000..34c67c6b3 --- /dev/null +++ b/modules/manager.py @@ -0,0 +1,83 @@ +# +# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py +# +# Original author comment: +# This file is the main thread that handles all gradio calls for major t2i or i2i processing. +# Other gradio calls (like those from extensions) are not influenced. +# By using one single thread to process all major calls, model moving is significantly faster. 
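+#
+# Rough flow of the implementation below:
+#   push_task()           - queue a callable together with a task id
+#   main_loop()           - runs on the main thread and executes queued tasks one by one
+#   run_and_wait_result() - called from gradio threads; polls finished_tasks for its
+#                           task id, re-raises any stored exception, returns the result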
+# +# 2024/09/28 classified, + +import random +import string +import threading +import time + +from collections import OrderedDict + + +class Task: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class TaskManager: + last_exception = None + pending_tasks = [] + finished_tasks = OrderedDict() + lock = None + running = False + + def __init__(self): + self.lock = threading.Lock() + + def work(self, task): + try: + task.result = task.func(*task.args, **task.kwargs) + except Exception as e: + task.exception = e + self.last_exception = e + + + def stop(self): + self.running = False + + + def main_loop(self): + self.running = True + while self.running: + time.sleep(0.01) + if len(self.pending_tasks) > 0: + with self.lock: + task = self.pending_tasks.pop(0) + + self.work(task) + + self.finished_tasks[task.task_id] = task + + + def push_task(self, func, *args, **kwargs): + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): + task_id = args[0] + else: + task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7)) + task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None) + self.pending_tasks.append(task) + + return task.task_id + + + def run_and_wait_result(self, func, *args, **kwargs): + current_id = self.push_task(func, *args, **kwargs) + + while True: + time.sleep(0.01) + if current_id in self.finished_tasks: + finished = self.finished_tasks.pop(current_id) + if finished.exception is not None: + raise finished.exception + + return finished.result + + +task = TaskManager() diff --git a/webui.py b/webui.py index 421e3b833..398d83550 100644 --- a/webui.py +++ b/webui.py @@ -6,6 +6,8 @@ import time from modules import timer from modules import initialize_util from modules import initialize +from modules import manager +from threading import Thread startup_timer = timer.startup_timer startup_timer.record("launcher") @@ -14,6 +16,8 @@ initialize.imports() initialize.check_versions() +initialize.initialize() + def create_api(app): from modules.api.api import Api @@ -23,12 +27,10 @@ def create_api(app): return api -def api_only(): +def _api_only(): from fastapi import FastAPI from modules.shared_cmd_options import cmd_opts - initialize.initialize() - app = FastAPI() initialize_util.setup_middleware(app) api = create_api(app) @@ -83,11 +85,10 @@ For more information see: https://github.com/AUTOMATIC1111/stable-diffusion-webu {"!"*25} Warning {"!"*25}''') -def webui(): +def _webui(): from modules.shared_cmd_options import cmd_opts launch_api = cmd_opts.api - initialize.initialize() from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks @@ -177,6 +178,7 @@ def webui(): print("Stopping server...") # If we catch a keyboard interrupt, we want to stop the server and exit. 
shared.demo.close() + manager.task.stop() break # disable auto launch webui in browser for subsequent UI Reload @@ -193,6 +195,13 @@ def webui(): initialize.initialize_rest(reload_script_modules=True) +def api_only(): + Thread(target=_api_only, daemon=True).start() + + +def webui(): + Thread(target=_webui, daemon=True).start() + if __name__ == "__main__": from modules.shared_cmd_options import cmd_opts @@ -200,3 +209,5 @@ if __name__ == "__main__": api_only() else: webui() + + manager.task.main_loop() From 0ab4d7992c4b3c65de7200a2adca0afa85907cc1 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 2 Oct 2024 20:02:30 +0900 Subject: [PATCH 45/50] reduce backup_weight size for float8 freeze model --- extensions-builtin/Lora/networks.py | 15 +++-- .../Lora/scripts/lora_script.py | 61 ++++++++++++++++++- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 12e1c24e1..76cef0a55 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -377,13 +377,13 @@ def allowed_layer_without_weight(layer): return False -def store_weights_backup(weight): +def store_weights_backup(weight, dtype): if weight is None: return None if shared.opts.lora_without_backup_weight: return True - return weight.to(devices.cpu, copy=True) + return weight.to(devices.cpu, dtype=dtype, copy=True) def restore_weights_backup(obj, field, weight): @@ -437,18 +437,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged") if isinstance(self, torch.nn.MultiheadAttention): - weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight)) + weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype)) else: - weights_backup = store_weights_backup(self.weight) + weights_backup = store_weights_backup(self.weight, self.org_dtype) self.network_weights_backup = weights_backup bias_backup = getattr(self, "network_bias_backup", None) if bias_backup is None and wanted_names != (): if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: - bias_backup = store_weights_backup(self.out_proj.bias) + bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype) elif getattr(self, 'bias', None) is not None: - bias_backup = store_weights_backup(self.bias) + bias_backup = store_weights_backup(self.bias, self.org_dtype) else: bias_backup = None @@ -487,6 +487,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) + del weight, bias, updown, ex_bias except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 @@ -538,6 +539,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) self.weight += updown_qkv del updown_qkv + del updown_q, updown_k, updown_v except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") @@ -560,6 +562,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn updown_qkv_mlp = torch.vstack([updown_q, 
updown_k, updown_v, updown_mlp]) self.weight += updown_qkv_mlp del updown_qkv_mlp + del updown_q, updown_k, updown_v, updown_mlp except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index d3ea369ae..8ee93efef 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,4 +1,5 @@ import re +import torch import gradio as gr from fastapi import FastAPI @@ -9,7 +10,7 @@ import lora # noqa:F401 import lora_patches import extra_networks_lora import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks, shared +from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices def unload(): @@ -97,6 +98,64 @@ def infotext_pasted(infotext, d): d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) +class ScriptLora(scripts.Script): + name = "Lora" + + def title(self): + return self.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def after_extra_networks_activate(self, p, *args, **kwargs): + # check modules and setup org_dtype + modules = [] + if shared.sd_model.is_sdxl: + for _i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for _name, module in embedder.wrapped.named_modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)): + if hasattr(module, 'weight'): + modules.append(module) + elif isinstance(module, torch.nn.MultiheadAttention): + modules.append(module) + + else: + cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model) + + for _name, module in cond_stage_model.named_modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)): + if hasattr(module, 'weight'): + modules.append(module) + elif isinstance(module, torch.nn.MultiheadAttention): + modules.append(module) + + for _name, module in shared.sd_model.model.named_modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)): + if hasattr(module, 'weight'): + modules.append(module) + elif isinstance(module, torch.nn.MultiheadAttention): + modules.append(module) + + print("Total lora modules after_extra_networks_activate() =", len(modules)) + + target_dtype = devices.dtype_inference + for module in modules: + if isinstance(module, torch.nn.MultiheadAttention): + org_dtype = torch.float32 + else: + org_dtype = None + for _name, param in module.named_parameters(): + if param.dtype != target_dtype: + org_dtype = param.dtype + break + + # set org_dtype + module.org_dtype = org_dtype + + script_callbacks.on_infotext_pasted(infotext_pasted) shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) From 04f9084253bbbe62be0e4b4e8d757c4c62b216ae Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 3 Oct 2024 18:04:51 +0900 Subject: [PATCH 46/50] extract backup/restore io-bound operations out of forward hooks to speed up --- extensions-builtin/Lora/networks.py | 34 +++++++++++++------ .../Lora/scripts/lora_script.py | 16 +++++++++ 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/Lora/networks.py 
b/extensions-builtin/Lora/networks.py index 76cef0a55..e45b82387 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -417,16 +417,8 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li restore_weights_backup(self, 'bias', bias_backup) -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): - """ - Applies the currently selected set of networks to the weights of torch layer self. - If weights already have this particular set of networks applied, does nothing. - If not, restores original weights from backup and alters weights according to networks. - """ - +def network_backup_weights(self): network_layer_name = getattr(self, 'network_layer_name', None) - if network_layer_name is None: - return current_names = getattr(self, "network_current_names", ()) wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) @@ -459,9 +451,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_bias_backup = bias_backup - if current_names != wanted_names: + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of networks to the weights of torch layer self. + If weights already have this particular set of networks applied, does nothing. + If not, restores original weights from backup and alters weights according to networks. + """ + + network_layer_name = getattr(self, 'network_layer_name', None) + if network_layer_name is None: + return + + current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None and wanted_names != (): + network_backup_weights(self) + elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False): network_restore_weights_from_backup(self) + if current_names != wanted_names: + if hasattr(self, "weights_restored"): + self.weights_restored = False + for net in loaded_networks: module = net.modules.get(network_layer_name, None) if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)): diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 8ee93efef..8163a05f3 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -143,6 +143,10 @@ class ScriptLora(scripts.Script): target_dtype = devices.dtype_inference for module in modules: + network_layer_name = getattr(module, 'network_layer_name', None) + if network_layer_name is None: + continue + if isinstance(module, torch.nn.MultiheadAttention): org_dtype = torch.float32 else: @@ -155,6 +159,18 @@ class ScriptLora(scripts.Script): # set org_dtype module.org_dtype = org_dtype + # backup/restore weights + current_names = getattr(module, "network_current_names", ()) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks) + + weights_backup = getattr(module, "network_weights_backup", None) + + if current_names == () and current_names 
!= wanted_names and weights_backup is None: + networks.network_backup_weights(module) + elif current_names != () and current_names != wanted_names: + networks.network_restore_weights_from_backup(module) + module.weights_restored = True + script_callbacks.on_infotext_pasted(infotext_pasted) From 2a1988fa67e204805b8d63942527c3843fd9a3df Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 3 Oct 2024 19:25:50 +0900 Subject: [PATCH 47/50] call gc.collect() when wanted_names == () --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 5 ++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index e45b82387..78d7407a0 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -394,7 +394,7 @@ def restore_weights_backup(obj, field, weight): getattr(obj, field).copy_(weight) -def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False): weights_backup = getattr(self, "network_weights_backup", None) bias_backup = getattr(self, "network_bias_backup", None) @@ -416,6 +416,12 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li else: restore_weights_backup(self, 'bias', bias_backup) + if cleanup: + if weights_backup is not None: + del self.network_weights_backup + if bias_backup is not None: + del self.network_bias_backup + def network_backup_weights(self): network_layer_name = getattr(self, 'network_layer_name', None) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 8163a05f3..7a23b8d57 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -4,6 +4,7 @@ import torch import gradio as gr from fastapi import FastAPI +import gc import network import networks import lora # noqa:F401 @@ -168,8 +169,10 @@ class ScriptLora(scripts.Script): if current_names == () and current_names != wanted_names and weights_backup is None: networks.network_backup_weights(module) elif current_names != () and current_names != wanted_names: - networks.network_restore_weights_from_backup(module) + networks.network_restore_weights_from_backup(module, wanted_names == ()) module.weights_restored = True + if current_names != wanted_names and wanted_names == (): + gc.collect() script_callbacks.on_infotext_pasted(infotext_pasted) From 412401becbf53c6c28d2fa46b9fe7d06cc2eaa79 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 3 Oct 2024 19:42:55 +0900 Subject: [PATCH 48/50] backup only for needed weights required by lora --- extensions-builtin/Lora/networks.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 78d7407a0..0ad5f3e71 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -429,6 +429,18 @@ def network_backup_weights(self): current_names = getattr(self, "network_current_names", ()) wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) + need_backup = False + for net in loaded_networks: + if network_layer_name in net.modules: + need_backup = True + break + 
elif network_layer_name + "_q_proj" in net.modules: + need_backup = True + break + + if not need_backup: + return + weights_backup = getattr(self, "network_weights_backup", None) if weights_backup is None and wanted_names != (): if current_names != () and not allowed_layer_without_weight(self): From b783a967c093a0471cb0d40b630bf3798ea2d6a1 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 4 Oct 2024 00:07:23 +0900 Subject: [PATCH 49/50] fix for lazy backup --- extensions-builtin/Lora/networks.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0ad5f3e71..66f262f4d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -426,7 +426,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li def network_backup_weights(self): network_layer_name = getattr(self, 'network_layer_name', None) - current_names = getattr(self, "network_current_names", ()) + _current_names = getattr(self, "network_current_names", ()) wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) need_backup = False @@ -443,9 +443,6 @@ def network_backup_weights(self): weights_backup = getattr(self, "network_weights_backup", None) if weights_backup is None and wanted_names != (): - if current_names != () and not allowed_layer_without_weight(self): - raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged") - if isinstance(self, torch.nn.MultiheadAttention): weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype)) else: @@ -462,11 +459,6 @@ def network_backup_weights(self): else: bias_backup = None - # Unlike weight which always has value, some modules don't have bias. - # Only report if bias is not None and current bias are not unchanged. - if bias_backup is not None and current_names != (): - raise RuntimeError("no backup bias found and current bias are not unchanged") - self.network_bias_backup = bias_backup From 310d0e6938e74fd461a4bf5a33d9a5b4d977b29d Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 4 Oct 2024 00:08:44 +0900 Subject: [PATCH 50/50] restore org_dtype != compute dtype case --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 66f262f4d..948fa6740 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -391,7 +391,8 @@ def restore_weights_backup(obj, field, weight): setattr(obj, field, None) return - getattr(obj, field).copy_(weight) + old_weight = getattr(obj, field) + old_weight.copy_(weight.to(dtype=old_weight.dtype)) def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
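
With this last change, the restore path casts the backed-up tensor to the live
parameter's dtype before the in-place copy, so a backup stored at a different dtype
(the org_dtype != compute dtype case from the subject) can still be restored. A
minimal standalone illustration of the cast-then-copy idea, using plain fp16/fp32
tensors rather than the webui module fields:

    import torch

    live_weight = torch.zeros(4, dtype=torch.float16)        # parameter as it sits in the model
    backup = torch.randn(4, dtype=torch.float32)              # backup kept at its original dtype
    live_weight.copy_(backup.to(dtype=live_weight.dtype))     # cast first, then copy in place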