This commit is contained in:
Won-Kyu Park 2025-12-20 22:18:57 +03:00 committed by GitHub
commit 4c937cbfab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1693 additions and 131 deletions

View file

@ -51,6 +51,7 @@ jobs:
--test-server
--do-not-download-clip
--no-half
--precision full
--disable-opt-split-attention
--use-cpu all
--api-server-stop

View file

@ -0,0 +1,4 @@
model:
target: modules.models.flux.FLUX1Inferencer
params:
state_dict: null

View file

@ -2,6 +2,7 @@ import torch
import lyco_helpers
import modules.models.sd3.mmdit
import modules.models.flux.modules.layers
import network
from modules import devices
@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule):
if weight is None and none_ok:
return None
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear ]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
if is_linear:

View file

@ -37,7 +37,7 @@ module_types = [
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
re_compiled = {}
suffix_conversion = {
@ -183,8 +183,12 @@ def load_network(name, network_on_disk):
for key_network, weight in sd.items():
if diffusers_weight_map:
key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight
if key_network.startswith("lora_unet"):
key_network_without_network_parts, _, network_part = key_network.partition(".")
key_network_without_network_parts = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
else:
key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight
else:
key_network_without_network_parts, _, network_part = key_network.partition(".")
@ -373,11 +377,13 @@ def allowed_layer_without_weight(layer):
return False
def store_weights_backup(weight):
def store_weights_backup(weight, dtype):
if weight is None:
return None
return weight.to(devices.cpu, copy=True)
if shared.opts.lora_without_backup_weight:
return True
return weight.to(devices.cpu, dtype=dtype, copy=True)
def restore_weights_backup(obj, field, weight):
@ -385,16 +391,20 @@ def restore_weights_backup(obj, field, weight):
setattr(obj, field, None)
return
getattr(obj, field).copy_(weight)
old_weight = getattr(obj, field)
old_weight.copy_(weight.to(dtype=old_weight.dtype))
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None and bias_backup is None:
return
if shared.opts.lora_without_backup_weight:
return
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
@ -407,6 +417,51 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
else:
restore_weights_backup(self, 'bias', bias_backup)
if cleanup:
if weights_backup is not None:
del self.network_weights_backup
if bias_backup is not None:
del self.network_bias_backup
def network_backup_weights(self):
network_layer_name = getattr(self, 'network_layer_name', None)
_current_names = getattr(self, "network_current_names", ())
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
need_backup = False
for net in loaded_networks:
if network_layer_name in net.modules:
need_backup = True
break
elif network_layer_name + "_q_proj" in net.modules:
need_backup = True
break
if not need_backup:
return
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
else:
weights_backup = store_weights_backup(self.weight, self.org_dtype)
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
elif getattr(self, 'bias', None) is not None:
bias_backup = store_weights_backup(self.bias, self.org_dtype)
else:
bias_backup = None
self.network_bias_backup = bias_backup
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
@ -424,38 +479,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
if current_names != () and not allowed_layer_without_weight(self):
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
else:
weights_backup = store_weights_backup(self.weight)
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = store_weights_backup(self.out_proj.bias)
elif getattr(self, 'bias', None) is not None:
bias_backup = store_weights_backup(self.bias)
else:
bias_backup = None
# Unlike weight which always has value, some modules don't have bias.
# Only report if bias is not None and current bias are not unchanged.
if bias_backup is not None and current_names != ():
raise RuntimeError("no backup bias found and current bias are not unchanged")
self.network_bias_backup = bias_backup
network_backup_weights(self)
elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False):
network_restore_weights_from_backup(self)
if current_names != wanted_names:
network_restore_weights_from_backup(self)
if hasattr(self, "weights_restored"):
self.weights_restored = False
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
try:
with torch.no_grad():
if getattr(self, 'fp16_weight', None) is None:
@ -478,6 +512,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
del weight, bias, updown, ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@ -515,7 +550,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
continue
if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]:
try:
with torch.no_grad():
# Send "real" orig_weight into MHA's lora module
@ -526,6 +563,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
del qw, kw, vw
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown_qkv
del updown_qkv
del updown_q, updown_k, updown_v
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v:
try:
with torch.no_grad():
qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw)
if module_mlp is not None:
updown_mlp, _ = module_mlp.calc_updown(mlp)
else:
updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device)
del qw, kw, vw, mlp
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
self.weight += updown_qkv_mlp
del updown_qkv_mlp
del updown_q, updown_k, updown_v, updown_mlp
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
@ -539,7 +601,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names
if shared.opts.lora_without_backup_weight:
self.network_weights_backup = None
self.network_bias_backup = None
else:
self.network_current_names = wanted_names
def network_forward(org_module, input, original_forward):

View file

@ -1,15 +1,17 @@
import re
import torch
import gradio as gr
from fastapi import FastAPI
import gc
import network
import networks
import lora # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices
def unload():
@ -97,6 +99,82 @@ def infotext_pasted(infotext, d):
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
class ScriptLora(scripts.Script):
name = "Lora"
def title(self):
return self.name
def show(self, is_img2img):
return scripts.AlwaysVisible
def after_extra_networks_activate(self, p, *args, **kwargs):
# check modules and setup org_dtype
modules = []
if shared.sd_model.is_sdxl:
for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
if not hasattr(embedder, 'wrapped'):
continue
for _name, module in embedder.wrapped.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)
else:
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
for _name, module in cond_stage_model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)
for _name, module in shared.sd_model.model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)
print("Total lora modules after_extra_networks_activate() =", len(modules))
target_dtype = devices.dtype_inference
for module in modules:
network_layer_name = getattr(module, 'network_layer_name', None)
if network_layer_name is None:
continue
if isinstance(module, torch.nn.MultiheadAttention):
org_dtype = torch.float32
else:
org_dtype = None
for _name, param in module.named_parameters():
if param.dtype != target_dtype:
org_dtype = param.dtype
break
# set org_dtype
module.org_dtype = org_dtype
# backup/restore weights
current_names = getattr(module, "network_current_names", ())
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks)
weights_backup = getattr(module, "network_weights_backup", None)
if current_names == () and current_names != wanted_names and weights_backup is None:
networks.network_backup_weights(module)
elif current_names != () and current_names != wanted_names:
networks.network_restore_weights_from_backup(module, wanted_names == ())
module.weights_restored = True
if current_names != wanted_names and wanted_names == ():
gc.collect()
script_callbacks.on_infotext_pasted(infotext_pasted)
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)

View file

@ -3,7 +3,7 @@ from functools import wraps
import html
import time
from modules import shared, progress, errors, devices, fifo_lock, profiling
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager
queue_lock = fifo_lock.FIFOLock()
@ -34,7 +34,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
progress.start_task(id_task)
try:
res = func(*args, **kwargs)
res = manager.task.run_and_wait_result(func, *args, **kwargs)
progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)

View file

@ -1,5 +1,6 @@
import sys
import contextlib
from copy import deepcopy
from functools import lru_cache
import torch
@ -128,6 +129,26 @@ dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
supported_vae_dtypes = [torch.float16, torch.float32]
# prepare available dtypes
if torch.version.cuda:
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
if has_xpu():
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
def supports_non_blocking():
if has_mps() or has_xpu():
return False
if npu_specific.has_npu:
return False
return True
def cond_cast_unet(input):
if force_fp16:
@ -146,17 +167,33 @@ patch_module_list = [
torch.nn.MultiheadAttention,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
torch.nn.Embedding,
]
def manual_cast_forward(target_dtype):
def manual_cast_forward(target_dtype, target_device=None, copy=False):
params = dict()
if supports_non_blocking():
params['non_blocking'] = True
supported_cast_dtypes = [torch.float16, torch.float32]
if torch.cuda.is_bf16_supported():
supported_cast_dtypes += [torch.bfloat16]
def forward_wrapper(self, *args, **kwargs):
if any(
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
for arg in args
):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
if target_device is not None:
params['device'] = target_device
params['dtype'] = target_dtype
args = list(args)
for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
if args[j].dtype in supported_cast_dtypes:
args[j] = args[j].to(**params)
args = tuple(args)
for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
if kwargs[key].dtype in supported_cast_dtypes:
kwargs[key] = kwargs[key].to(**params)
org_dtype = target_dtype
for param in self.parameters():
@ -164,38 +201,52 @@ def manual_cast_forward(target_dtype):
org_dtype = param.dtype
break
if org_dtype != target_dtype:
self.to(target_dtype)
result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype:
self.to(org_dtype)
if copy and not isinstance(self, torch.nn.Embedding):
copied = deepcopy(self)
if org_dtype != target_dtype:
copied.to(**params)
result = copied.org_forward(*args, **kwargs)
del copied
else:
if org_dtype != target_dtype:
self.to(**params)
result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype:
params['dtype'] = org_dtype
self.to(**params)
if target_dtype != dtype_inference:
params['dtype'] = dtype_inference
if isinstance(result, tuple):
result = tuple(
i.to(dtype_inference)
i.to(**params)
if isinstance(i, torch.Tensor)
else i
for i in result
)
elif isinstance(result, torch.Tensor):
result = result.to(dtype_inference)
result = result.to(**params)
return result
return forward_wrapper
@contextlib.contextmanager
def manual_cast(target_dtype):
def manual_cast(target_dtype, target_device=None, copy=None):
applied = False
for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
continue
applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention:
module_type.forward = manual_cast_forward(torch.float32)
module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
else:
module_type.forward = manual_cast_forward(target_dtype)
module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
module_type.org_forward = org_forward
try:
yield None
@ -207,26 +258,37 @@ def manual_cast(target_dtype):
delattr(module_type, "org_forward")
def autocast(disable=False):
def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None):
if disable:
return contextlib.nullcontext()
copy = copy if copy is not None else shared.opts.lora_without_backup_weight
if target_dtype is None:
target_dtype = dtype
if force_fp16:
# No casting during inference if force_fp16 is enabled.
# All tensor dtype conversion happens before inference.
return contextlib.nullcontext()
if fp8 and device==cpu:
if fp8 and target_device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and dtype_inference == torch.float32:
return manual_cast(dtype)
return manual_cast(target_dtype, target_device, copy=copy)
if dtype == torch.float32 or dtype_inference == torch.float32:
if target_dtype != dtype_inference or copy:
return manual_cast(target_dtype, target_device, copy=copy)
if current_dtype is not None and current_dtype != target_dtype:
return manual_cast(target_dtype, target_device, copy=copy)
if target_dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
if has_xpu() or has_mps() or cuda_no_autocast():
return manual_cast(dtype)
return manual_cast(target_dtype, target_device)
return torch.autocast("cuda")

View file

@ -495,11 +495,17 @@ def configure_for_tests():
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
import webui
from modules import manager
if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()
manager.task.main_loop()
return
def dump_sysinfo():
from modules import sysinfo

View file

@ -53,6 +53,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
devices.torch_gc()
module.to(devices.device)
module_in_gpu = module

83
modules/manager.py Normal file
View file

@ -0,0 +1,83 @@
#
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
#
# Original author comment:
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not influenced.
# By using one single thread to process all major calls, model moving is significantly faster.
#
# 2024/09/28 classified,
import random
import string
import threading
import time
from collections import OrderedDict
class Task:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class TaskManager:
last_exception = None
pending_tasks = []
finished_tasks = OrderedDict()
lock = None
running = False
def __init__(self):
self.lock = threading.Lock()
def work(self, task):
try:
task.result = task.func(*task.args, **task.kwargs)
except Exception as e:
task.exception = e
self.last_exception = e
def stop(self):
self.running = False
def main_loop(self):
self.running = True
while self.running:
time.sleep(0.01)
if len(self.pending_tasks) > 0:
with self.lock:
task = self.pending_tasks.pop(0)
self.work(task)
self.finished_tasks[task.task_id] = task
def push_task(self, func, *args, **kwargs):
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
task_id = args[0]
else:
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
self.pending_tasks.append(task)
return task.task_id
def run_and_wait_result(self, func, *args, **kwargs):
current_id = self.push_task(func, *args, **kwargs)
while True:
time.sleep(0.01)
if current_id in self.finished_tasks:
finished = self.finished_tasks.pop(current_id)
if finished.exception is not None:
raise finished.exception
return finished.result
task = TaskManager()

View file

@ -0,0 +1,5 @@
from .flux import FLUX1Inferencer
__all__ = [
"FLUX1Inferencer",
]

360
modules/models/flux/flux.py Normal file
View file

@ -0,0 +1,360 @@
import contextlib
import os
import safetensors
import torch
import math
import k_diffusion
from transformers import CLIPTokenizer
from modules import shared, devices, modelloader, sd_hijack_clip
from modules.models.sd3.sd3_impls import SDVAE
from modules.models.sd3.sd3_cond import CLIPL_CONFIG, T5_CONFIG, CLIPL_URL, T5_URL, SafetensorsMapping, Sd3T5
from modules.models.sd3.other_impls import SDClipModel, T5XXLModel, SDTokenizer, T5XXLTokenizer
from PIL import Image
from .model import Flux
class FluxTokenizer:
def __init__(self):
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.t5xxl = T5XXLTokenizer()
def tokenize_with_weights(self, text:str):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
return out
class Flux1ClipL(sd_hijack_clip.TextConditionalModel):
def __init__(self, clip_l):
super().__init__()
self.clip_l = clip_l
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
empty = self.tokenizer('')["input_ids"]
self.id_start = empty[0]
self.id_end = empty[1]
self.id_pad = empty[1]
self.return_pooled = True
def tokenize(self, texts):
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
def encode_with_transformers(self, tokens):
l_out, l_pooled = self.clip_l(tokens)
l_out = torch.cat([l_out], dim=-1)
l_out = torch.nn.functional.pad(l_out, (0, 4096 - l_out.shape[-1]))
vector_out = torch.cat([l_pooled], dim=-1)
l_out.pooled = vector_out
return l_out
def encode_embedding_init_text(self, init_text, nvpt):
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
class FluxCond(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = FluxTokenizer()
with torch.no_grad():
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.flux_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
else:
self.t5xxl = None
self.model_l = Flux1ClipL(self.clip_l)
self.model_t5 = Sd3T5(self.t5xxl)
def forward(self, prompts: list[str]):
with devices.without_autocast():
l_out, vector_out = self.model_l(prompts)
t5_out = self.model_t5(prompts, token_count=l_out.shape[1])
lt_out = torch.cat([l_out, t5_out], dim=-2)
return {
'crossattn': lt_out,
'vector': vector_out,
}
def before_load_weights(self, state_dict):
clip_path = os.path.join(shared.models_path, "CLIP")
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
def encode_embedding_init_text(self, init_text, nvpt):
return self.model_l.encode_embedding_init_text(init_text, nvpt)
def tokenize(self, texts):
return self.model_l.tokenize(texts)
def medvram_modules(self):
return [self.clip_l, self.t5xxl]
def get_token_count(self, text):
_, token_count = self.model_l.process_texts([text])
return token_count
def get_target_prompt_token_count(self, token_count):
return self.model_l.get_target_prompt_token_count(token_count)
def flux_time_shift(mu: float, sigma: float, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class ModelSamplingFlux(torch.nn.Module):
def __init__(self, shift=1.15):
super().__init__()
self.set_parameters(shift=shift)
def set_parameters(self, shift=1.15, timesteps=10000):
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma
def sigma(self, timestep):
return flux_time_shift(self.shift, 1.0, timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
class BaseModel(torch.nn.Module):
"""Wrapper around the core FLUX model"""
def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs):
super().__init__()
self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs)
self.model_sampling = ModelSamplingFlux(shift=shift)
self.depth = kwargs['depth']
self.depth_single_block = kwargs['depth_single_blocks']
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32)
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype)
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
def get_dtype(self):
return self.diffusion_model.dtype
class FLUX1LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
def __init__(self, scale_factor=0.3611, shift_factor=0.1159):
self.scale_factor = scale_factor
self.shift_factor = shift_factor
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
def decode_latent_to_preview(self, x0):
"""Quick RGB approximate preview of sd3 latents"""
factors = torch.tensor([
[-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680],
], device="cpu")
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
class FLUX1Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
self.inner_model = inner_model
def forward(self, input, sigma, **kwargs):
return self.inner_model.apply_model(input, sigma, **kwargs)
class FLUX1Inferencer(torch.nn.Module):
def __init__(self, state_dict, use_ema=False):
super().__init__()
params = dict(
image_model="flux",
in_channels=16,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10000,
qkv_bias=True,
guidance_embed=True,
)
# detect model_prefix
diffusion_model_prefix = "model.diffusion_model."
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
diffusion_model_prefix = "model.diffusion_model."
elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
diffusion_model_prefix = ""
shift=1.15
# check guidance_in to detect Flux schnell
if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict:
print("Flux schnell detected")
params.update(dict(guidance_embed=False,))
shift=1.0
with torch.no_grad():
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = devices.dtype_vae
self.vae = self.first_stage_model # real vae
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.text_encoders = FluxCond()
self.cond_stage_key = 'txt'
self.parameterization = "eps"
self.model.conditioning_key = "crossattn"
self.latent_format = FLUX1LatentFormat()
self.latent_channels = 16
@property
def cond_stage_model(self):
return self.text_encoders
def before_load_weights(self, state_dict):
self.cond_stage_model.before_load_weights(state_dict)
def ema_scope(self):
return contextlib.nullcontext()
def get_learned_conditioning(self, batch: list[str]):
return self.cond_stage_model(batch)
def apply_model(self, x, t, cond):
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
x = self.first_stage_model.decode(latent)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
def encode_first_stage(self, image):
latent = self.first_stage_model.encode(image)
return self.latent_format.process_in(latent)
def get_first_stage_encoding(self, x):
return x
def create_denoiser(self):
return FLUX1Denoiser(self, self.model.model_sampling.sigmas)
def medvram_fields(self):
return [
(self, 'first_stage_model'),
(self, 'text_encoders'),
(self, 'model'),
]
def add_noise_to_latent(self, x, noise, amount):
return x * (1 - amount) + noise * amount
def fix_dimensions(self, width, height):
return width // 16 * 16, height // 16 * 16
def diffusers_weight_mapping(self):
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# please see also https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
for i in range(self.model.depth):
yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_add_out", f"diffusion_model_double_blocks_{i}_txt_attn_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_double_blocks_{i}_img_attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_double_blocks_{i}_img_attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_double_blocks_{i}_img_attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_double_blocks_{i}_img_attn_proj"
yield f"transformer.transformer_blocks.{i}.ff.net.0.proj", f"diffusion_model_double_blocks_{i}_img_mlp_0"
yield f"transformer.transformer_blocks.{i}.ff.net.2", f"diffusion_model_double_blocks_{i}_img_mlp_2"
yield f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", f"diffusion_model_double_blocks_{i}_txt_mlp_0"
yield f"transformer.transformer_blocks.{i}.ff_context.net.2", f"diffusion_model_double_blocks_{i}_txt_mlp_2"
yield f"transformer.transformer_blocks.{i}.norm1.linear", f"diffusion_model_double_blocks_{i}_img_mod_lin"
yield f"transformer.transformer_blocks.{i}.norm1_context.linear", f"diffusion_model_double_blocks_{i}_txt_mod_lin"
for i in range(self.model.depth_single_block):
yield f"transformer.single_transformer_blocks.{i}.attn.to_q", f"diffusion_model_single_blocks_{i}_linear1_q_proj"
yield f"transformer.single_transformer_blocks.{i}.attn.to_k", f"diffusion_model_single_blocks_{i}_linear1_k_proj"
yield f"transformer.single_transformer_blocks.{i}.attn.to_v", f"diffusion_model_single_blocks_{i}_linear1_v_proj"
yield f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj"
yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2"
yield f"transformer.single_transformer_blocks.{i}.norm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin"

View file

@ -0,0 +1,30 @@
import torch
from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View file

@ -0,0 +1,147 @@
# original code from https://github.com/black-forest-labs/flux
#
from dataclasses import dataclass
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
if self.final_layer:
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
# from comfy/ldm/common_dit.py
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
bs, c, h, w = x.shape
patch_size = 2
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View file

@ -0,0 +1,265 @@
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from ..math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class QkvLinear(torch.nn.Linear):
pass
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device)
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, dtype=None, device=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = QkvLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

201
modules/models/flux/util.py Normal file
View file

@ -0,0 +1,201 @@
import os
from dataclasses import dataclass
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft
from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str | None
ae_path: str | None
repo_id: str | None
repo_flow: str | None
repo_ae: str | None
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
# Loading Flux
print("Init model")
ckpt_path = configs[name].ckpt_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
with torch.device("meta" if ckpt_path is not None else device):
model = Flux(configs[name].params).to(torch.bfloat16)
if ckpt_path is not None:
print("Loading checkpoint")
# load_sft doesn't support torch.device
sd = load_sft(ckpt_path, device=str(device))
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return model
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
ckpt_path = configs[name].ae_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
# Loading the autoencoder
print("Init AE")
with torch.device("meta" if ckpt_path is not None else device):
ae = AutoEncoder(configs[name].ae_params)
if ckpt_path is not None:
sd = load_sft(ckpt_path, device=str(device))
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return ae
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, RGB, H, W) in range [-1, 1]
Returns:
same as input but watermarked
"""
image = 0.5 * image + 0.5
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
# watermarking libary expects input as cv2 BGR format
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
image.device
)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
image = 2 * image - 1
return image
# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)

View file

@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear):
def forward(self, x):
return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
class AutocastLayerNorm(nn.LayerNorm):
def forward(self, x):
return torch.nn.functional.layer_norm(
x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps)
def attention(q, k, v, heads, mask=None):
"""Convenience wrapper around a basic attention operation"""
@ -41,9 +46,9 @@ class Mlp(nn.Module):
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.act = act_layer
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device):
super().__init__()
self.heads = heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None):
q = self.q_proj(x)
@ -82,9 +87,11 @@ ACTIVATIONS = {
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
super().__init__()
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
#self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
#self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
#self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module):
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module):
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device)
embed_dim = config_dict["hidden_size"]
self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection = AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
@ -370,7 +377,7 @@ class T5Attention(torch.nn.Module):
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@ -442,7 +449,7 @@ class T5Attention(torch.nn.Module):
else:
mask = None
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None)
return self.o(out), past_bias
@ -475,19 +482,21 @@ class T5Block(torch.nn.Module):
class T5Stack(torch.nn.Module):
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
super().__init__()
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
#self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)
self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True):
intermediate = None
x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes
#x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes
# some T5XXL do not embed_token. use shared token instead like comfy
past_bias = None
for i, layer in enumerate(self.block):
x, past_bias = layer(x, past_bias)
if i == intermediate_output:
intermediate = x.clone()
x = self.final_layer_norm(x)
x = torch.nan_to_num(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.final_layer_norm(intermediate)
return x, intermediate
@ -498,13 +507,18 @@ class T5(torch.nn.Module):
super().__init__()
self.num_layers = config_dict["num_layers"]
self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32)
self.dtype = dtype
def get_input_embeddings(self):
return self.encoder.embed_tokens
#return self.encoder.embed_tokens
return self.shared
def set_input_embeddings(self, embeddings):
self.encoder.embed_tokens = embeddings
#self.encoder.embed_tokens = embeddings
self.shared = embeddings
def forward(self, *args, **kwargs):
return self.encoder(*args, **kwargs)
def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids).float()
x = torch.nan_to_num(x)
return self.encoder(x, *args, **kwargs)

View file

@ -43,7 +43,7 @@ CLIPG_CONFIG = {
"textual_inversion_key": "clip_g",
}
T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16_e4m3fn.safetensors"
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
@ -140,7 +140,7 @@ class Sd3T5(torch.nn.Module):
return tokens, multipliers
def forward(self, texts, *, token_count):
if not self.t5xxl or not shared.opts.sd3_enable_t5:
if not self.t5xxl:
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
tokens_batch = []
@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module):
self.tokenizer = SD3Tokenizer()
with torch.no_grad():
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.sd3_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
else:
self.t5xxl = None
@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module):
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

View file

@ -484,7 +484,7 @@ class StableDiffusionProcessing:
cache = caches[0]
with devices.autocast():
with devices.autocast(target_dtype=devices.dtype_inference):
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
cache[0] = cached_params
@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
sd_models.apply_alpha_schedule_override(p.sd_model, p)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet):
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if p.scripts is not None:
@ -1147,6 +1147,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess(p, res)
if lowvram.is_enabled(shared.sd_model):
# for interrupted case
lowvram.send_everything_to_cpu()
return res
@ -1439,7 +1444,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
with devices.autocast():
with devices.autocast(target_dtype=devices.dtype_inference):
self.calculate_hr_conds()
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

View file

@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
self.state_dict = state_dict
self.device = device
self.weight_dtype_conversion = weight_dtype_conversion or {}
self.default_dtype = self.weight_dtype_conversion.get('')
self.default_dtype = self.weight_dtype_conversion.get('', None)
def get_weight_dtype(self, key):
key_first_term, _ = key.split('.', 1)
@ -176,6 +176,11 @@ class LoadStateDictOnMeta(ReplaceHelper):
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm,)):
# HACK add assign=True to local_metadata for some cases
args[0]['assign_to_params_buffers'] = True
for name, param in module._parameters.items():
if param is None:
continue
@ -183,12 +188,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
dtype = self.get_weight_dtype(key)
if dtype is None:
state_dict[key] = sd_param
else:
state_dict[key] = sd_param.to(dtype=dtype)
used_param_keys.append(key)
if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
module._parameters[name] = torch.nn.parameter.Parameter(torch.empty_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
for name in module._buffers:
key = prefix + name

View file

@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
if isinstance(cond, dict):
for y in cond.keys():
if isinstance(cond[y], list):
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]]
else:
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast():
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs)
if devices.unet_needs_upcast:
return result.float()
else:
@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
torch.nn.GELU.__init__(self, *args, **kwargs)
def forward(self, x):
if devices.unet_needs_upcast:
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference)
else:
return torch.nn.GELU.forward(self, x)
@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
dtype = torch.float32
else:
dtype = devices.dtype_unet
dtype = devices.dtype_inference
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)

View file

@ -34,6 +34,7 @@ class ModelType(enum.Enum):
SDXL = 3
SSD = 4
SD3 = 5
FLUX1 = 6
def replace_key(d, key, new_key, value):
@ -267,6 +268,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
def fix_unet_prefix(state_dict):
known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
for k in state_dict.keys():
found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
if len(found) > 0:
return state_dict
# no known prefix found.
# in this case, this is a unet only state_dict
known_keys = (
"input_blocks.0.0.weight", # SD1.5, SD2, SDXL
"joint_blocks.0.context_block.adaLN_modulation.1.weight", # SD3
"double_blocks.0.img_attn.proj.weight", # FLUX
)
if any(key in state_dict for key in known_keys):
state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()}
print("Fixed state_dict keys...")
return state_dict
return state_dict
def read_metadata_from_safetensors(filename):
import json
@ -328,6 +353,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
res = read_state_dict(checkpoint_info.filename)
res = fix_unet_prefix(res)
timer.record("load weights from disk")
return res
@ -355,7 +381,7 @@ def check_fp8(model):
enable_fp8 = False
elif shared.opts.fp8_storage == "Enable":
enable_fp8 = True
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL":
enable_fp8 = True
else:
enable_fp8 = False
@ -368,10 +394,14 @@ def set_model_type(model, state_dict):
model.is_sdxl = False
model.is_ssd = False
model.is_sd3 = False
model.is_flux1 = False
if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
model.is_sd3 = True
model.model_type = ModelType.SD3
elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
model.is_flux1 = True
model.model_type = ModelType.FLUX1
elif hasattr(model, 'conditioner'):
model.is_sdxl = True
@ -393,6 +423,82 @@ def set_model_fields(model):
model.latent_channels = 4
def get_state_dict_dtype(state_dict):
# detect dtypes of state_dict
state_dict_dtype = {}
known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")
for k in state_dict.keys():
found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
if len(found) > 0:
prefix = found[0]
dtype = state_dict[k].dtype
dtypes = state_dict_dtype.get(prefix, {})
if dtype in dtypes:
dtypes[dtype] += 1
else:
dtypes[dtype] = 1
state_dict_dtype[prefix] = dtypes
for prefix in state_dict_dtype:
dtypes = state_dict_dtype[prefix]
# sort by count
state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True))
print("Detected dtypes:", state_dict_dtype)
return state_dict_dtype
def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None):
if state_dict is not None:
state_dict_dtype = get_state_dict_dtype(state_dict)
# get the first dtype
if prefix in state_dict_dtype:
return list(state_dict_dtype[prefix])[0]
return None
def get_vae_dtype(state_dict=None, state_dict_dtype=None):
if state_dict is not None:
state_dict_dtype = get_state_dict_dtype(state_dict)
if state_dict_dtype is None:
raise ValueError("fail to get vae dtype")
vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype]
if len(vae_prefixes) > 0:
vae_prefix = vae_prefixes[0]
for dtype in state_dict_dtype[vae_prefix]:
if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16):
# vae items: 248 for SD1, SDXL 245 for flux
return dtype
return None
def fix_position_ids(state_dict, force=False):
# for SD1.5 or some SDXL with position_ids
for prefix in ("cond_stage_models.", "conditioner.embedders.0."):
position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids"
if position_id_key in state_dict:
original = state_dict[position_id_key]
if original.dtype == torch.int64:
return
if force:
# regenerate
fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device)
else:
fixed = state_dict[position_id_key].to(torch.int64)
print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}")
state_dict[position_id_key] = fixed
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
@ -414,6 +520,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
else:
model.ztsnr = False
fix_position_ids(state_dict)
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@ -427,6 +536,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if hasattr(model, "before_load_weights"):
model.before_load_weights(state_dict)
# get all dtypes of state_dict
state_dict_dtype = get_state_dict_dtype(state_dict)
model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model")
@ -452,7 +564,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
if shared.cmd_opts.no_half:
# check dtype of vae
dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16)
if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half:
# unet type is not detected or unet has float dtypes
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
@ -462,8 +580,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
# preserve bfloat16 if it supported
model.first_stage_model = None
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
elif shared.cmd_opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.cmd_opts.upcast_sampling and depth_model:
@ -471,15 +592,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
alphas_cumprod = model.alphas_cumprod
model.alphas_cumprod = None
model.half()
if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
model.half()
elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
pass
else:
print("Fail to get a vaild UNet dtype. ignore...")
model.alphas_cumprod = alphas_cumprod
model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
devices.dtype_unet = torch.float16
timer.record("apply half()")
if found_unet_dtype in (torch.float16, torch.float32):
devices.dtype_unet = torch.float16
timer.record("apply half()")
else:
print(f"load Unet {found_unet_dtype} as is ...")
devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16
timer.record("load UNet")
apply_alpha_schedule_override(model)
@ -489,10 +623,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if hasattr(module, 'fp16_bias'):
del module.fp16_bias
if check_fp8(model):
if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model):
devices.fp8 = True
# do not convert vae, text_encoders.clip_l, clip_g, t5xxl
first_stage = model.first_stage_model
model.first_stage_model = None
vae = getattr(model, 'vae', None)
if vae is not None:
model.vae = None
text_encoders = getattr(model, 'text_encoders', None)
if text_encoders is not None:
model.text_encoders = None
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
@ -500,6 +642,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if module.bias is not None:
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
if text_encoders is not None:
model.text_encoders = text_encoders
if vae is not None:
model.vae = vae
model.first_stage_model = first_stage
timer.record("apply fp8")
else:
@ -507,8 +653,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
timer.record("apply dtype to VAE")
# check supported vae dtype
dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
devices.dtype_vae = torch.bfloat16
print(f"VAE dtype {dtype_vae} detected. load as is.")
else:
# use default devices.dtype_vae
model.first_stage_model.to(devices.dtype_vae)
print(f"Use VAE dtype {devices.dtype_vae}")
timer.record("apply dtype to VAE")
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
@ -661,6 +815,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight'
clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight'
t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight'
class SdModelData:
@ -793,7 +950,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
if not checkpoint_config:
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight ] if x in state_dict)
timer.record("find config")
@ -804,6 +961,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
print(f"Creating model from config: {checkpoint_config}")
# get all dtypes of state_dict
state_dict_dtype = get_state_dict_dtype(state_dict)
# check loadable unet dtype before loading
loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
# check dtype of vae
dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
devices.dtype_vae = torch.bfloat16
print(f"VAE dtype {dtype_vae} detected.")
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
@ -828,8 +997,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
else:
weight_dtype_conversion = {
'first_stage_model': None,
'text_encoders': None,
'vae': None,
'alphas_cumprod': None,
'': torch.float16,
'': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None,
}
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
@ -856,7 +1027,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
with devices.autocast(target_dtype=devices.dtype_inference), torch.no_grad():
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")

View file

@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename):
if "model.diffusion_model.x_embedder.proj.weight" in sd:
return config_sd3
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
return config_flux1
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting

View file

@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion):
is_sd3: bool
"""True if the model's architecture is SD 3"""
is_flux1: bool
"""True if the model's architecture is FLUX 1"""
latent_channels: int
"""number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""

View file

@ -18,7 +18,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
devices_args = dict(device=devices.device, dtype=devices.dtype)
devices_args = dict(device=devices.device, dtype=devices.dtype_inference)
sdxl_conds = {
"txt": batch,

View file

@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = 255. * np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)

View file

@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
loaded = False
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
else:
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
loaded = _load_vae_dict(model, vae_dict_1)
if cache_enabled:
if loaded and cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()
# clean up cache if limit is reached
if cache_enabled:
if loaded and cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
if loaded and vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
elif loaded_vae_file:
restore_base_vae(model)
loaded = True
loaded_vae_file = vae_file
if loaded:
loaded_vae_file = vae_file
model.base_vae = base_vae
model.loaded_vae_file = loaded_vae_file
return loaded
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
print("Failed to load VAE. Size mismatched!")
return False
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
return True
def clear_loaded_vae():
@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file, vae_source)
loaded = load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
script_callbacks.model_loaded_callback(sd_model)
print("VAE weights loaded.")
if loaded:
print("VAE weights loaded.")
return sd_model

View file

@ -44,6 +44,8 @@ def model():
model_name = "vaeapprox-sd3.pt"
elif shared.sd_model.is_sdxl:
model_name = "vaeapprox-sdxl.pt"
elif shared.sd_model.is_flux1:
model_name = "vaeapprox-sd3.pt"
else:
model_name = "model.pt"
@ -81,6 +83,18 @@ def cheap_approximation(sample):
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
]
elif shared.sd_model.is_flux1:
coeffs = [
# from comfy
[-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680],
]
elif shared.sd_model.is_sdxl:
coeffs = [
[ 0.3448, 0.4168, 0.4395],

View file

@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module):
super().__init__()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4
self.decoder = decoder(latent_channels)
self.decoder.load_state_dict(
@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module):
super().__init__()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4
self.encoder = encoder(latent_channels)
self.encoder.load_state_dict(
@ -97,6 +97,8 @@ def download_model(model_path, model_url):
def decoder_model():
if shared.sd_model.is_sd3:
model_name = "taesd3_decoder.pth"
elif shared.sd_model.is_flux1:
model_name = "taef1_decoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_decoder.pth"
else:
@ -122,6 +124,8 @@ def decoder_model():
def encoder_model():
if shared.sd_model.is_sd3:
model_name = "taesd3_encoder.pth"
elif shared.sd_model.is_flux1:
model_name = "taef1_encoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_encoder.pth"
else:

View file

@ -29,7 +29,7 @@ def initialize():
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype_inference
if cmd_opts.precision == "half":
msg = "--no-half and --no-half-vae conflict with --precision half"

View file

@ -196,6 +196,9 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"),
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
"sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
}))
options_templates.update(options_section(('flux', "Stable Diffusion FLUX", "sd"), {
"flux_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
}))
options_templates.update(options_section(('vae', "VAE", "sd"), {
"sd_vae_explanation": OptionHTML("""
@ -243,6 +246,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"),
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
"lora_without_backup_weight": OptionInfo(False, "LoRA without backup weights").info("LoRA without backup weights to save RAM."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {

View file

@ -6,6 +6,8 @@ import time
from modules import timer
from modules import initialize_util
from modules import initialize
from modules import manager
from threading import Thread
startup_timer = timer.startup_timer
startup_timer.record("launcher")
@ -14,6 +16,8 @@ initialize.imports()
initialize.check_versions()
initialize.initialize()
def create_api(app):
from modules.api.api import Api
@ -23,12 +27,10 @@ def create_api(app):
return api
def api_only():
def _api_only():
from fastapi import FastAPI
from modules.shared_cmd_options import cmd_opts
initialize.initialize()
app = FastAPI()
initialize_util.setup_middleware(app)
api = create_api(app)
@ -83,11 +85,10 @@ For more information see: https://github.com/AUTOMATIC1111/stable-diffusion-webu
{"!"*25} Warning {"!"*25}''')
def webui():
def _webui():
from modules.shared_cmd_options import cmd_opts
launch_api = cmd_opts.api
initialize.initialize()
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
@ -177,6 +178,7 @@ def webui():
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
manager.task.stop()
break
# disable auto launch webui in browser for subsequent UI Reload
@ -193,6 +195,13 @@ def webui():
initialize.initialize_rest(reload_script_modules=True)
def api_only():
Thread(target=_api_only, daemon=True).start()
def webui():
Thread(target=_webui, daemon=True).start()
if __name__ == "__main__":
from modules.shared_cmd_options import cmd_opts
@ -200,3 +209,5 @@ if __name__ == "__main__":
api_only()
else:
webui()
manager.task.main_loop()