diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 0610f4f54..eccf15e46 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -51,6 +51,7 @@ jobs: --test-server --do-not-download-clip --no-half + --precision full --disable-opt-split-attention --use-cpu all --api-server-stop diff --git a/configs/flux1-inference.yaml b/configs/flux1-inference.yaml new file mode 100644 index 000000000..f9bbe9073 --- /dev/null +++ b/configs/flux1-inference.yaml @@ -0,0 +1,4 @@ +model: + target: modules.models.flux.FLUX1Inferencer + params: + state_dict: null diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index a7a088949..2bc6af5d2 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -2,6 +2,7 @@ import torch import lyco_helpers import modules.models.sd3.mmdit +import modules.models.flux.modules.layers import network from modules import devices @@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule): if weight is None and none_ok: return None - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear] + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear ] is_conv = type(self.sd_module) in [torch.nn.Conv2d] if is_linear: diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 67f9abe2a..948fa6740 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -37,7 +37,7 @@ module_types = [ re_digits = re.compile(r"\d+") -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$") re_compiled = {} suffix_conversion = { @@ -183,8 +183,12 @@ def load_network(name, network_on_disk): for key_network, weight in sd.items(): if diffusers_weight_map: - key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2) - network_part = network_name + '.' + network_weight + if key_network.startswith("lora_unet"): + key_network_without_network_parts, _, network_part = key_network.partition(".") + key_network_without_network_parts = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + else: + key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2) + network_part = network_name + '.' + network_weight else: key_network_without_network_parts, _, network_part = key_network.partition(".") @@ -373,11 +377,13 @@ def allowed_layer_without_weight(layer): return False -def store_weights_backup(weight): +def store_weights_backup(weight, dtype): if weight is None: return None - return weight.to(devices.cpu, copy=True) + if shared.opts.lora_without_backup_weight: + return True + return weight.to(devices.cpu, dtype=dtype, copy=True) def restore_weights_backup(obj, field, weight): @@ -385,16 +391,20 @@ def restore_weights_backup(obj, field, weight): setattr(obj, field, None) return - getattr(obj, field).copy_(weight) + old_weight = getattr(obj, field) + old_weight.copy_(weight.to(dtype=old_weight.dtype)) -def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False): weights_backup = getattr(self, "network_weights_backup", None) bias_backup = getattr(self, "network_bias_backup", None) if weights_backup is None and bias_backup is None: return + if shared.opts.lora_without_backup_weight: + return + if weights_backup is not None: if isinstance(self, torch.nn.MultiheadAttention): restore_weights_backup(self, 'in_proj_weight', weights_backup[0]) @@ -407,6 +417,51 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li else: restore_weights_backup(self, 'bias', bias_backup) + if cleanup: + if weights_backup is not None: + del self.network_weights_backup + if bias_backup is not None: + del self.network_bias_backup + + +def network_backup_weights(self): + network_layer_name = getattr(self, 'network_layer_name', None) + + _current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) + + need_backup = False + for net in loaded_networks: + if network_layer_name in net.modules: + need_backup = True + break + elif network_layer_name + "_q_proj" in net.modules: + need_backup = True + break + + if not need_backup: + return + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None and wanted_names != (): + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype)) + else: + weights_backup = store_weights_backup(self.weight, self.org_dtype) + + self.network_weights_backup = weights_backup + + bias_backup = getattr(self, "network_bias_backup", None) + if bias_backup is None and wanted_names != (): + if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: + bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype) + elif getattr(self, 'bias', None) is not None: + bias_backup = store_weights_backup(self.bias, self.org_dtype) + else: + bias_backup = None + + self.network_bias_backup = bias_backup + def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): """ @@ -424,38 +479,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn weights_backup = getattr(self, "network_weights_backup", None) if weights_backup is None and wanted_names != (): - if current_names != () and not allowed_layer_without_weight(self): - raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged") - - if isinstance(self, torch.nn.MultiheadAttention): - weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight)) - else: - weights_backup = store_weights_backup(self.weight) - - self.network_weights_backup = weights_backup - - bias_backup = getattr(self, "network_bias_backup", None) - if bias_backup is None and wanted_names != (): - if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: - bias_backup = store_weights_backup(self.out_proj.bias) - elif getattr(self, 'bias', None) is not None: - bias_backup = store_weights_backup(self.bias) - else: - bias_backup = None - - # Unlike weight which always has value, some modules don't have bias. - # Only report if bias is not None and current bias are not unchanged. - if bias_backup is not None and current_names != (): - raise RuntimeError("no backup bias found and current bias are not unchanged") - - self.network_bias_backup = bias_backup + network_backup_weights(self) + elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False): + network_restore_weights_from_backup(self) if current_names != wanted_names: - network_restore_weights_from_backup(self) + if hasattr(self, "weights_restored"): + self.weights_restored = False for net in loaded_networks: module = net.modules.get(network_layer_name, None) - if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear): + if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)): try: with torch.no_grad(): if getattr(self, 'fp16_weight', None) is None: @@ -478,6 +512,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) + del weight, bias, updown, ex_bias except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 @@ -515,7 +550,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn continue - if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v: + module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None) + + if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]: try: with torch.no_grad(): # Send "real" orig_weight into MHA's lora module @@ -526,6 +563,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn del qw, kw, vw updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) self.weight += updown_qkv + del updown_qkv + del updown_q, updown_k, updown_v + + except RuntimeError as e: + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + + continue + + if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v: + try: + with torch.no_grad(): + qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0) + updown_q, _ = module_q.calc_updown(qw) + updown_k, _ = module_k.calc_updown(kw) + updown_v, _ = module_v.calc_updown(vw) + if module_mlp is not None: + updown_mlp, _ = module_mlp.calc_updown(mlp) + else: + updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device) + del qw, kw, vw, mlp + updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp]) + self.weight += updown_qkv_mlp + del updown_qkv_mlp + del updown_q, updown_k, updown_v, updown_mlp except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") @@ -539,7 +601,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 - self.network_current_names = wanted_names + + if shared.opts.lora_without_backup_weight: + self.network_weights_backup = None + self.network_bias_backup = None + else: + self.network_current_names = wanted_names def network_forward(org_module, input, original_forward): diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index d3ea369ae..7a23b8d57 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,15 +1,17 @@ import re +import torch import gradio as gr from fastapi import FastAPI +import gc import network import networks import lora # noqa:F401 import lora_patches import extra_networks_lora import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks, shared +from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices def unload(): @@ -97,6 +99,82 @@ def infotext_pasted(infotext, d): d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) +class ScriptLora(scripts.Script): + name = "Lora" + + def title(self): + return self.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def after_extra_networks_activate(self, p, *args, **kwargs): + # check modules and setup org_dtype + modules = [] + if shared.sd_model.is_sdxl: + for _i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for _name, module in embedder.wrapped.named_modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)): + if hasattr(module, 'weight'): + modules.append(module) + elif isinstance(module, torch.nn.MultiheadAttention): + modules.append(module) + + else: + cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model) + + for _name, module in cond_stage_model.named_modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)): + if hasattr(module, 'weight'): + modules.append(module) + elif isinstance(module, torch.nn.MultiheadAttention): + modules.append(module) + + for _name, module in shared.sd_model.model.named_modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)): + if hasattr(module, 'weight'): + modules.append(module) + elif isinstance(module, torch.nn.MultiheadAttention): + modules.append(module) + + print("Total lora modules after_extra_networks_activate() =", len(modules)) + + target_dtype = devices.dtype_inference + for module in modules: + network_layer_name = getattr(module, 'network_layer_name', None) + if network_layer_name is None: + continue + + if isinstance(module, torch.nn.MultiheadAttention): + org_dtype = torch.float32 + else: + org_dtype = None + for _name, param in module.named_parameters(): + if param.dtype != target_dtype: + org_dtype = param.dtype + break + + # set org_dtype + module.org_dtype = org_dtype + + # backup/restore weights + current_names = getattr(module, "network_current_names", ()) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks) + + weights_backup = getattr(module, "network_weights_backup", None) + + if current_names == () and current_names != wanted_names and weights_backup is None: + networks.network_backup_weights(module) + elif current_names != () and current_names != wanted_names: + networks.network_restore_weights_from_backup(module, wanted_names == ()) + module.weights_restored = True + if current_names != wanted_names and wanted_names == (): + gc.collect() + + script_callbacks.on_infotext_pasted(infotext_pasted) shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) diff --git a/modules/call_queue.py b/modules/call_queue.py index 555c35312..b20badcaf 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -3,7 +3,7 @@ from functools import wraps import html import time -from modules import shared, progress, errors, devices, fifo_lock, profiling +from modules import shared, progress, errors, devices, fifo_lock, profiling, manager queue_lock = fifo_lock.FIFOLock() @@ -34,7 +34,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): progress.start_task(id_task) try: - res = func(*args, **kwargs) + res = manager.task.run_and_wait_result(func, *args, **kwargs) progress.record_results(id_task, res) finally: progress.finish_task(id_task) diff --git a/modules/devices.py b/modules/devices.py index ee679141a..5b763ec85 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,5 +1,6 @@ import sys import contextlib +from copy import deepcopy from functools import lru_cache import torch @@ -128,6 +129,26 @@ dtype_unet: torch.dtype = torch.float16 dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False +supported_vae_dtypes = [torch.float16, torch.float32] + + +# prepare available dtypes +if torch.version.cuda: + if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: + supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes + if has_xpu(): + supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes + + +def supports_non_blocking(): + if has_mps() or has_xpu(): + return False + + if npu_specific.has_npu: + return False + + return True + def cond_cast_unet(input): if force_fp16: @@ -146,17 +167,33 @@ patch_module_list = [ torch.nn.MultiheadAttention, torch.nn.GroupNorm, torch.nn.LayerNorm, + torch.nn.Embedding, ] -def manual_cast_forward(target_dtype): +def manual_cast_forward(target_dtype, target_device=None, copy=False): + params = dict() + if supports_non_blocking(): + params['non_blocking'] = True + + supported_cast_dtypes = [torch.float16, torch.float32] + if torch.cuda.is_bf16_supported(): + supported_cast_dtypes += [torch.bfloat16] + def forward_wrapper(self, *args, **kwargs): - if any( - isinstance(arg, torch.Tensor) and arg.dtype != target_dtype - for arg in args - ): - args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + if target_device is not None: + params['device'] = target_device + params['dtype'] = target_dtype + + args = list(args) + for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype): + if args[j].dtype in supported_cast_dtypes: + args[j] = args[j].to(**params) + args = tuple(args) + + for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype): + if kwargs[key].dtype in supported_cast_dtypes: + kwargs[key] = kwargs[key].to(**params) org_dtype = target_dtype for param in self.parameters(): @@ -164,38 +201,52 @@ def manual_cast_forward(target_dtype): org_dtype = param.dtype break - if org_dtype != target_dtype: - self.to(target_dtype) - result = self.org_forward(*args, **kwargs) - if org_dtype != target_dtype: - self.to(org_dtype) + if copy and not isinstance(self, torch.nn.Embedding): + copied = deepcopy(self) + if org_dtype != target_dtype: + copied.to(**params) + + result = copied.org_forward(*args, **kwargs) + del copied + else: + if org_dtype != target_dtype: + self.to(**params) + + result = self.org_forward(*args, **kwargs) + + if org_dtype != target_dtype: + params['dtype'] = org_dtype + self.to(**params) if target_dtype != dtype_inference: + params['dtype'] = dtype_inference if isinstance(result, tuple): result = tuple( - i.to(dtype_inference) + i.to(**params) if isinstance(i, torch.Tensor) else i for i in result ) elif isinstance(result, torch.Tensor): - result = result.to(dtype_inference) + result = result.to(**params) return result return forward_wrapper @contextlib.contextmanager -def manual_cast(target_dtype): +def manual_cast(target_dtype, target_device=None, copy=None): applied = False + + for module_type in patch_module_list: if hasattr(module_type, "org_forward"): continue applied = True org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention: - module_type.forward = manual_cast_forward(torch.float32) + module_type.forward = manual_cast_forward(torch.float32, target_device, copy) else: - module_type.forward = manual_cast_forward(target_dtype) + module_type.forward = manual_cast_forward(target_dtype, target_device, copy) module_type.org_forward = org_forward try: yield None @@ -207,26 +258,37 @@ def manual_cast(target_dtype): delattr(module_type, "org_forward") -def autocast(disable=False): +def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None): if disable: return contextlib.nullcontext() + copy = copy if copy is not None else shared.opts.lora_without_backup_weight + + if target_dtype is None: + target_dtype = dtype + if force_fp16: # No casting during inference if force_fp16 is enabled. # All tensor dtype conversion happens before inference. return contextlib.nullcontext() - if fp8 and device==cpu: + if fp8 and target_device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and dtype_inference == torch.float32: - return manual_cast(dtype) + return manual_cast(target_dtype, target_device, copy=copy) - if dtype == torch.float32 or dtype_inference == torch.float32: + if target_dtype != dtype_inference or copy: + return manual_cast(target_dtype, target_device, copy=copy) + + if current_dtype is not None and current_dtype != target_dtype: + return manual_cast(target_dtype, target_device, copy=copy) + + if target_dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype) + return manual_cast(target_dtype, target_device) return torch.autocast("cuda") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index cc97a67b0..efaefdc3e 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -495,11 +495,17 @@ def configure_for_tests(): def start(): print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}") import webui + + from modules import manager + if '--nowebui' in sys.argv: webui.api_only() else: webui.webui() + manager.task.main_loop() + return + def dump_sysinfo(): from modules import sysinfo diff --git a/modules/lowvram.py b/modules/lowvram.py index 6728c337b..9914a06c6 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,6 +53,7 @@ def setup_for_low_vram(sd_model, use_medvram): if module_in_gpu is not None: module_in_gpu.to(cpu) + devices.torch_gc() module.to(devices.device) module_in_gpu = module diff --git a/modules/manager.py b/modules/manager.py new file mode 100644 index 000000000..34c67c6b3 --- /dev/null +++ b/modules/manager.py @@ -0,0 +1,83 @@ +# +# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py +# +# Original author comment: +# This file is the main thread that handles all gradio calls for major t2i or i2i processing. +# Other gradio calls (like those from extensions) are not influenced. +# By using one single thread to process all major calls, model moving is significantly faster. +# +# 2024/09/28 classified, + +import random +import string +import threading +import time + +from collections import OrderedDict + + +class Task: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class TaskManager: + last_exception = None + pending_tasks = [] + finished_tasks = OrderedDict() + lock = None + running = False + + def __init__(self): + self.lock = threading.Lock() + + def work(self, task): + try: + task.result = task.func(*task.args, **task.kwargs) + except Exception as e: + task.exception = e + self.last_exception = e + + + def stop(self): + self.running = False + + + def main_loop(self): + self.running = True + while self.running: + time.sleep(0.01) + if len(self.pending_tasks) > 0: + with self.lock: + task = self.pending_tasks.pop(0) + + self.work(task) + + self.finished_tasks[task.task_id] = task + + + def push_task(self, func, *args, **kwargs): + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): + task_id = args[0] + else: + task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7)) + task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None) + self.pending_tasks.append(task) + + return task.task_id + + + def run_and_wait_result(self, func, *args, **kwargs): + current_id = self.push_task(func, *args, **kwargs) + + while True: + time.sleep(0.01) + if current_id in self.finished_tasks: + finished = self.finished_tasks.pop(current_id) + if finished.exception is not None: + raise finished.exception + + return finished.result + + +task = TaskManager() diff --git a/modules/models/flux/__init__.py b/modules/models/flux/__init__.py new file mode 100644 index 000000000..1cc52a00b --- /dev/null +++ b/modules/models/flux/__init__.py @@ -0,0 +1,5 @@ +from .flux import FLUX1Inferencer + +__all__ = [ + "FLUX1Inferencer", +] diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py new file mode 100644 index 000000000..a7370af25 --- /dev/null +++ b/modules/models/flux/flux.py @@ -0,0 +1,360 @@ +import contextlib + +import os +import safetensors +import torch +import math + +import k_diffusion +from transformers import CLIPTokenizer + +from modules import shared, devices, modelloader, sd_hijack_clip + +from modules.models.sd3.sd3_impls import SDVAE +from modules.models.sd3.sd3_cond import CLIPL_CONFIG, T5_CONFIG, CLIPL_URL, T5_URL, SafetensorsMapping, Sd3T5 +from modules.models.sd3.other_impls import SDClipModel, T5XXLModel, SDTokenizer, T5XXLTokenizer +from PIL import Image + +from .model import Flux + + +class FluxTokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text:str): + out = {} + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + return out + + +class Flux1ClipL(sd_hijack_clip.TextConditionalModel): + def __init__(self, clip_l): + super().__init__() + + self.clip_l = clip_l + + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + empty = self.tokenizer('')["input_ids"] + self.id_start = empty[0] + self.id_end = empty[1] + self.id_pad = empty[1] + + self.return_pooled = True + + def tokenize(self, texts): + return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + def encode_with_transformers(self, tokens): + l_out, l_pooled = self.clip_l(tokens) + l_out = torch.cat([l_out], dim=-1) + l_out = torch.nn.functional.pad(l_out, (0, 4096 - l_out.shape[-1])) + + vector_out = torch.cat([l_pooled], dim=-1) + + l_out.pooled = vector_out + + return l_out + + def encode_embedding_init_text(self, init_text, nvpt): + return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX + + + +class FluxCond(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.tokenizer = FluxTokenizer() + + with torch.no_grad(): + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + + if shared.opts.flux_enable_t5: + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference) + else: + self.t5xxl = None + + self.model_l = Flux1ClipL(self.clip_l) + self.model_t5 = Sd3T5(self.t5xxl) + + def forward(self, prompts: list[str]): + with devices.without_autocast(): + l_out, vector_out = self.model_l(prompts) + t5_out = self.model_t5(prompts, token_count=l_out.shape[1]) + lt_out = torch.cat([l_out, t5_out], dim=-2) + + return { + 'crossattn': lt_out, + 'vector': vector_out, + } + + def before_load_weights(self, state_dict): + clip_path = os.path.join(shared.models_path, "CLIP") + + if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict: + clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") + with safetensors.safe_open(clip_l_file, framework="pt") as file: + self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict: + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors") + with safetensors.safe_open(t5_file, framework="pt") as file: + self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + + def encode_embedding_init_text(self, init_text, nvpt): + return self.model_l.encode_embedding_init_text(init_text, nvpt) + + def tokenize(self, texts): + return self.model_l.tokenize(texts) + + def medvram_modules(self): + return [self.clip_l, self.t5xxl] + + def get_token_count(self, text): + _, token_count = self.model_l.process_texts([text]) + + return token_count + + def get_target_prompt_token_count(self, token_count): + return self.model_l.get_target_prompt_token_count(token_count) + +def flux_time_shift(mu: float, sigma: float, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + +class ModelSamplingFlux(torch.nn.Module): + def __init__(self, shift=1.15): + super().__init__() + + self.set_parameters(shift=shift) + + def set_parameters(self, shift=1.15, timesteps=10000): + self.shift = shift + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps)) + self.register_buffer('sigmas', ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma + + def sigma(self, timestep): + return flux_time_shift(self.shift, 1.0, timestep) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 1.0 + if percent >= 1.0: + return 0.0 + return 1.0 - percent + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + +class BaseModel(torch.nn.Module): + """Wrapper around the core FLUX model""" + def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs): + super().__init__() + + self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs) + self.model_sampling = ModelSamplingFlux(shift=shift) + self.depth = kwargs['depth'] + self.depth_single_block = kwargs['depth_single_blocks'] + + def apply_model(self, x, sigma, c_crossattn=None, y=None): + dtype = self.get_dtype() + timestep = self.model_sampling.timestep(sigma).float() + guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32) + model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype) + return self.model_sampling.calculate_denoised(sigma, model_output, x) + + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) + + def get_dtype(self): + return self.diffusion_model.dtype + + +class FLUX1LatentFormat: + """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift""" + def __init__(self, scale_factor=0.3611, shift_factor=0.1159): + self.scale_factor = scale_factor + self.shift_factor = shift_factor + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scale_factor + + def process_out(self, latent): + return (latent / self.scale_factor) + self.shift_factor + + def decode_latent_to_preview(self, x0): + """Quick RGB approximate preview of sd3 latents""" + factors = torch.tensor([ + [-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850], + [ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549], + [ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123], + [ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306], + [-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545], + [-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088], + [ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897], + [-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680], + ], device="cpu") + latent_image = x0[0].permute(1, 2, 0).cpu() @ factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +class FLUX1Denoiser(k_diffusion.external.DiscreteSchedule): + def __init__(self, inner_model, sigmas): + super().__init__(sigmas, quantize=shared.opts.enable_quantization) + self.inner_model = inner_model + + def forward(self, input, sigma, **kwargs): + return self.inner_model.apply_model(input, sigma, **kwargs) + + +class FLUX1Inferencer(torch.nn.Module): + def __init__(self, state_dict, use_ema=False): + super().__init__() + + params = dict( + image_model="flux", + in_channels=16, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10000, + qkv_bias=True, + guidance_embed=True, + ) + + # detect model_prefix + diffusion_model_prefix = "model.diffusion_model." + if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + diffusion_model_prefix = "model.diffusion_model." + elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + diffusion_model_prefix = "" + + shift=1.15 + # check guidance_in to detect Flux schnell + if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict: + print("Flux schnell detected") + params.update(dict(guidance_embed=False,)) + shift=1.0 + + with torch.no_grad(): + self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params) + self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) + self.first_stage_model.dtype = devices.dtype_vae + self.vae = self.first_stage_model # real vae + + self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) + + self.text_encoders = FluxCond() + self.cond_stage_key = 'txt' + + self.parameterization = "eps" + self.model.conditioning_key = "crossattn" + + self.latent_format = FLUX1LatentFormat() + self.latent_channels = 16 + + @property + def cond_stage_model(self): + return self.text_encoders + + def before_load_weights(self, state_dict): + self.cond_stage_model.before_load_weights(state_dict) + + def ema_scope(self): + return contextlib.nullcontext() + + def get_learned_conditioning(self, batch: list[str]): + return self.cond_stage_model(batch) + + def apply_model(self, x, t, cond): + return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector']) + + def decode_first_stage(self, latent): + latent = self.latent_format.process_out(latent) + x = self.first_stage_model.decode(latent) + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + + def encode_first_stage(self, image): + latent = self.first_stage_model.encode(image) + return self.latent_format.process_in(latent) + + def get_first_stage_encoding(self, x): + return x + + def create_denoiser(self): + return FLUX1Denoiser(self, self.model.model_sampling.sigmas) + + def medvram_fields(self): + return [ + (self, 'first_stage_model'), + (self, 'text_encoders'), + (self, 'model'), + ] + + def add_noise_to_latent(self, x, noise, amount): + return x * (1 - amount) + noise * amount + + def fix_dimensions(self, width, height): + return width // 16 * 16, height // 16 * 16 + + def diffusers_weight_mapping(self): + # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py + # please see also https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py + for i in range(self.model.depth): + yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_v_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_add_out", f"diffusion_model_double_blocks_{i}_txt_attn_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_double_blocks_{i}_img_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_double_blocks_{i}_img_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_double_blocks_{i}_img_attn_qkv_v_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_double_blocks_{i}_img_attn_proj" + + yield f"transformer.transformer_blocks.{i}.ff.net.0.proj", f"diffusion_model_double_blocks_{i}_img_mlp_0" + yield f"transformer.transformer_blocks.{i}.ff.net.2", f"diffusion_model_double_blocks_{i}_img_mlp_2" + yield f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", f"diffusion_model_double_blocks_{i}_txt_mlp_0" + yield f"transformer.transformer_blocks.{i}.ff_context.net.2", f"diffusion_model_double_blocks_{i}_txt_mlp_2" + yield f"transformer.transformer_blocks.{i}.norm1.linear", f"diffusion_model_double_blocks_{i}_img_mod_lin" + yield f"transformer.transformer_blocks.{i}.norm1_context.linear", f"diffusion_model_double_blocks_{i}_txt_mod_lin" + + for i in range(self.model.depth_single_block): + yield f"transformer.single_transformer_blocks.{i}.attn.to_q", f"diffusion_model_single_blocks_{i}_linear1_q_proj" + yield f"transformer.single_transformer_blocks.{i}.attn.to_k", f"diffusion_model_single_blocks_{i}_linear1_k_proj" + yield f"transformer.single_transformer_blocks.{i}.attn.to_v", f"diffusion_model_single_blocks_{i}_linear1_v_proj" + yield f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj" + + yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2" + yield f"transformer.single_transformer_blocks.{i}.norm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin" diff --git a/modules/models/flux/math.py b/modules/models/flux/math.py new file mode 100644 index 000000000..4ad99818e --- /dev/null +++ b/modules/models/flux/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/modules/models/flux/model.py b/modules/models/flux/model.py new file mode 100644 index 000000000..c87397827 --- /dev/null +++ b/modules/models/flux/model.py @@ -0,0 +1,147 @@ +# original code from https://github.com/black-forest-labs/flux +# +from dataclasses import dataclass + +import torch + +from einops import rearrange, repeat +from torch import Tensor, nn + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs): + super().__init__() + + self.dtype = dtype + params = FluxParams(**kwargs) + self.params = params + self.in_channels = params.in_channels * 2 * 2 + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + dtype=dtype, device=device, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device) + for _ in range(params.depth_single_blocks) + ] + ) + + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device) + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + if self.final_layer: + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + def forward(self, x, timestep, context, y, guidance, **kwargs): + # from comfy/ldm/common_dit.py + def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): + if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): + padding_mode = "reflect" + pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] + pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] + return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) + + bs, c, h, w = x.shape + patch_size = 2 + x = pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance) + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] diff --git a/modules/models/flux/modules/layers.py b/modules/models/flux/modules/layers.py new file mode 100644 index 000000000..2202f5dcb --- /dev/null +++ b/modules/models/flux/modules/layers.py @@ -0,0 +1,265 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from ..math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, dtype=None, device=None): + super().__init__() + self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int, dtype=None, device=None): + super().__init__() + self.query_norm = RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = RMSNorm(dim, dtype=dtype, device=device) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class QkvLinear(torch.nn.Linear): + pass + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.norm = QKNorm(head_dim, dtype=dtype, device=device) + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool, dtype=None, device=None): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + + def forward(self, vec: Tensor) -> tuple: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) + + + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float = None, + dtype=None, + device=None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = QkvLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + + self.norm = QKNorm(head_dim, dtype=dtype, device=device) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + x += mod.gate * output + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/modules/models/flux/util.py b/modules/models/flux/util.py new file mode 100644 index 000000000..9303eb7cf --- /dev/null +++ b/modules/models/flux/util.py @@ -0,0 +1,201 @@ +import os +from dataclasses import dataclass + +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from imwatermark import WatermarkEncoder +from safetensors.torch import load_file as load_sft + +from .model import Flux, FluxParams +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was chosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watermark = WatermarkEmbedder(WATERMARK_BITS) diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py index 78c1dc687..4524fa019 100644 --- a/modules/models/sd3/other_impls.py +++ b/modules/models/sd3/other_impls.py @@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear): def forward(self, x): return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) +class AutocastLayerNorm(nn.LayerNorm): + def forward(self, x): + return torch.nn.functional.layer_norm( + x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps) + def attention(q, k, v, heads, mask=None): """Convenience wrapper around a basic attention operation""" @@ -41,9 +46,9 @@ class Mlp(nn.Module): out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) self.act = act_layer - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) def forward(self, x): x = self.fc1(x) @@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device): super().__init__() self.heads = heads - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) def forward(self, x, mask=None): q = self.q_proj(x) @@ -82,9 +87,11 @@ ACTIVATIONS = { class CLIPLayer(torch.nn.Module): def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + #self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device) self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device) + #self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device) @@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module): super().__init__() self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l')) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device) def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): x = self.embeddings(input_tokens) @@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module): self.num_layers = config_dict["num_hidden_layers"] self.text_model = CLIPTextModel_(config_dict, dtype, device) embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection = AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) self.text_projection.weight.copy_(torch.eye(embed_dim)) self.dtype = dtype @@ -370,7 +377,7 @@ class T5Attention(torch.nn.Module): if relative_attention_bias: self.relative_attention_num_buckets = 32 self.relative_attention_max_distance = 128 - self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32) @staticmethod def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): @@ -442,7 +449,7 @@ class T5Attention(torch.nn.Module): else: mask = None - out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None) + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None) return self.o(out), past_bias @@ -475,19 +482,21 @@ class T5Block(torch.nn.Module): class T5Stack(torch.nn.Module): def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): super().__init__() - self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + #self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32) self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)]) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True): intermediate = None - x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes + #x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes + # some T5XXL do not embed_token. use shared token instead like comfy past_bias = None for i, layer in enumerate(self.block): x, past_bias = layer(x, past_bias) if i == intermediate_output: intermediate = x.clone() x = self.final_layer_norm(x) + x = torch.nan_to_num(x) if intermediate is not None and final_layer_norm_intermediate: intermediate = self.final_layer_norm(intermediate) return x, intermediate @@ -498,13 +507,18 @@ class T5(torch.nn.Module): super().__init__() self.num_layers = config_dict["num_layers"] self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) + self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32) self.dtype = dtype def get_input_embeddings(self): - return self.encoder.embed_tokens + #return self.encoder.embed_tokens + return self.shared def set_input_embeddings(self, embeddings): - self.encoder.embed_tokens = embeddings + #self.encoder.embed_tokens = embeddings + self.shared = embeddings - def forward(self, *args, **kwargs): - return self.encoder(*args, **kwargs) + def forward(self, input_ids, *args, **kwargs): + x = self.shared(input_ids).float() + x = torch.nan_to_num(x) + return self.encoder(x, *args, **kwargs) diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py index 6a43f569b..fc0232325 100644 --- a/modules/models/sd3/sd3_cond.py +++ b/modules/models/sd3/sd3_cond.py @@ -43,7 +43,7 @@ CLIPG_CONFIG = { "textual_inversion_key": "clip_g", } -T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors" +T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16_e4m3fn.safetensors" T5_CONFIG = { "d_ff": 10240, "d_model": 4096, @@ -140,7 +140,7 @@ class Sd3T5(torch.nn.Module): return tokens, multipliers def forward(self, texts, *, token_count): - if not self.t5xxl or not shared.opts.sd3_enable_t5: + if not self.t5xxl: return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype) tokens_batch = [] @@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module): self.tokenizer = SD3Tokenizer() with torch.no_grad(): - self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype) - self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) if shared.opts.sd3_enable_t5: - self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference) else: self.t5xxl = None @@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module): with safetensors.safe_open(clip_l_file, framework="pt") as file: self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) - if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: - t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict: + t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors") with safetensors.safe_open(t5_file, framework="pt") as file: self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) diff --git a/modules/processing.py b/modules/processing.py index 92c3582cc..ed140983d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -484,7 +484,7 @@ class StableDiffusionProcessing: cache = caches[0] - with devices.autocast(): + with devices.autocast(target_dtype=devices.dtype_inference): cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) cache[0] = cached_params @@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: sd_models.apply_alpha_schedule_override(p.sd_model, p) - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if p.scripts is not None: @@ -1147,6 +1147,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess(p, res) + + if lowvram.is_enabled(shared.sd_model): + # for interrupted case + lowvram.send_everything_to_cpu() + return res @@ -1439,7 +1444,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.hr_extra_network_data) - with devices.autocast(): + with devices.autocast(target_dtype=devices.dtype_inference): self.calculate_hr_conds() sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 3750e85e9..0261db08f 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper): self.state_dict = state_dict self.device = device self.weight_dtype_conversion = weight_dtype_conversion or {} - self.default_dtype = self.weight_dtype_conversion.get('') + self.default_dtype = self.weight_dtype_conversion.get('', None) def get_weight_dtype(self, key): key_first_term, _ = key.split('.', 1) @@ -176,6 +176,11 @@ class LoadStateDictOnMeta(ReplaceHelper): def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): used_param_keys = [] + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm,)): + # HACK add assign=True to local_metadata for some cases + args[0]['assign_to_params_buffers'] = True + + for name, param in module._parameters.items(): if param is None: continue @@ -183,12 +188,16 @@ class LoadStateDictOnMeta(ReplaceHelper): key = prefix + name sd_param = sd.pop(key, None) if sd_param is not None: - state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) + dtype = self.get_weight_dtype(key) + if dtype is None: + state_dict[key] = sd_param + else: + state_dict[key] = sd_param.to(dtype=dtype) used_param_keys.append(key) if param.is_meta: dtype = sd_param.dtype if sd_param is not None else param.dtype - module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + module._parameters[name] = torch.nn.parameter.Parameter(torch.empty_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) for name in module._buffers: key = prefix + name diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index b4f03b138..842030be8 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]] else: - cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] + cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs) if devices.unet_needs_upcast: return result.float() else: @@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): torch.nn.GELU.__init__(self, *args, **kwargs) def forward(self, x): if devices.unet_needs_upcast: - return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference) else: return torch.nn.GELU.forward(self, x) @@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) - CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 @@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): if devices.unet_needs_upcast and timesteps.dtype == torch.int64: dtype = torch.float32 else: - dtype = devices.dtype_unet + dtype = devices.dtype_inference return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) diff --git a/modules/sd_models.py b/modules/sd_models.py index f4274ae42..d273951e2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -34,6 +34,7 @@ class ModelType(enum.Enum): SDXL = 3 SSD = 4 SD3 = 5 + FLUX1 = 6 def replace_key(d, key, new_key, value): @@ -267,6 +268,30 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +def fix_unet_prefix(state_dict): + known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.") + + for k in state_dict.keys(): + found = [prefix for prefix in known_prefixes if k.startswith(prefix)] + if len(found) > 0: + return state_dict + + # no known prefix found. + # in this case, this is a unet only state_dict + known_keys = ( + "input_blocks.0.0.weight", # SD1.5, SD2, SDXL + "joint_blocks.0.context_block.adaLN_modulation.1.weight", # SD3 + "double_blocks.0.img_attn.proj.weight", # FLUX + ) + + if any(key in state_dict for key in known_keys): + state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()} + print("Fixed state_dict keys...") + return state_dict + + return state_dict + + def read_metadata_from_safetensors(filename): import json @@ -328,6 +353,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") res = read_state_dict(checkpoint_info.filename) + res = fix_unet_prefix(res) timer.record("load weights from disk") return res @@ -355,7 +381,7 @@ def check_fp8(model): enable_fp8 = False elif shared.opts.fp8_storage == "Enable": enable_fp8 = True - elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL": enable_fp8 = True else: enable_fp8 = False @@ -368,10 +394,14 @@ def set_model_type(model, state_dict): model.is_sdxl = False model.is_ssd = False model.is_sd3 = False + model.is_flux1 = False if "model.diffusion_model.x_embedder.proj.weight" in state_dict: model.is_sd3 = True model.model_type = ModelType.SD3 + elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + model.is_flux1 = True + model.model_type = ModelType.FLUX1 elif hasattr(model, 'conditioner'): model.is_sdxl = True @@ -393,6 +423,82 @@ def set_model_fields(model): model.latent_channels = 4 +def get_state_dict_dtype(state_dict): + # detect dtypes of state_dict + state_dict_dtype = {} + + known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.") + + for k in state_dict.keys(): + found = [prefix for prefix in known_prefixes if k.startswith(prefix)] + if len(found) > 0: + prefix = found[0] + dtype = state_dict[k].dtype + dtypes = state_dict_dtype.get(prefix, {}) + if dtype in dtypes: + dtypes[dtype] += 1 + else: + dtypes[dtype] = 1 + state_dict_dtype[prefix] = dtypes + + for prefix in state_dict_dtype: + dtypes = state_dict_dtype[prefix] + # sort by count + state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True)) + + print("Detected dtypes:", state_dict_dtype) + return state_dict_dtype + + +def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None): + if state_dict is not None: + state_dict_dtype = get_state_dict_dtype(state_dict) + + # get the first dtype + if prefix in state_dict_dtype: + return list(state_dict_dtype[prefix])[0] + return None + + +def get_vae_dtype(state_dict=None, state_dict_dtype=None): + if state_dict is not None: + state_dict_dtype = get_state_dict_dtype(state_dict) + + if state_dict_dtype is None: + raise ValueError("fail to get vae dtype") + + + vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype] + + if len(vae_prefixes) > 0: + vae_prefix = vae_prefixes[0] + for dtype in state_dict_dtype[vae_prefix]: + if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16): + # vae items: 248 for SD1, SDXL 245 for flux + return dtype + + return None + + +def fix_position_ids(state_dict, force=False): + # for SD1.5 or some SDXL with position_ids + for prefix in ("cond_stage_models.", "conditioner.embedders.0."): + position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids" + if position_id_key in state_dict: + original = state_dict[position_id_key] + if original.dtype == torch.int64: + return + + if force: + # regenerate + fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device) + else: + fixed = state_dict[position_id_key].to(torch.int64) + print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}") + + state_dict[position_id_key] = fixed + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") @@ -414,6 +520,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer else: model.ztsnr = False + fix_position_ids(state_dict) + + if model.is_sdxl: sd_models_xl.extend_sdxl(model) @@ -427,6 +536,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(model, "before_load_weights"): model.before_load_weights(state_dict) + # get all dtypes of state_dict + state_dict_dtype = get_state_dict_dtype(state_dict) + model.load_state_dict(state_dict, strict=False) timer.record("apply weights to model") @@ -452,7 +564,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.to(memory_format=torch.channels_last) timer.record("apply channels_last") - if shared.cmd_opts.no_half: + # check dtype of vae + dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype) + found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype) + unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) + + if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half: + # unet type is not detected or unet has float dtypes model.float() model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 @@ -462,8 +580,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer vae = model.first_stage_model depth_model = getattr(model, 'depth_model', None) + if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes: + # preserve bfloat16 if it supported + model.first_stage_model = None # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: + elif shared.cmd_opts.no_half_vae: model.first_stage_model = None # with --upcast-sampling, don't convert the depth model weights to float16 if shared.cmd_opts.upcast_sampling and depth_model: @@ -471,15 +592,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer alphas_cumprod = model.alphas_cumprod model.alphas_cumprod = None - model.half() + + + if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16): + model.half() + elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + pass + else: + print("Fail to get a vaild UNet dtype. ignore...") + model.alphas_cumprod = alphas_cumprod model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model - devices.dtype_unet = torch.float16 - timer.record("apply half()") + if found_unet_dtype in (torch.float16, torch.float32): + devices.dtype_unet = torch.float16 + timer.record("apply half()") + else: + print(f"load Unet {found_unet_dtype} as is ...") + devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16 + timer.record("load UNet") apply_alpha_schedule_override(model) @@ -489,10 +623,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(module, 'fp16_bias'): del module.fp16_bias - if check_fp8(model): + if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model): devices.fp8 = True + + # do not convert vae, text_encoders.clip_l, clip_g, t5xxl first_stage = model.first_stage_model model.first_stage_model = None + vae = getattr(model, 'vae', None) + if vae is not None: + model.vae = None + text_encoders = getattr(model, 'text_encoders', None) + if text_encoders is not None: + model.text_encoders = None for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: @@ -500,6 +642,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if module.bias is not None: module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) + if text_encoders is not None: + model.text_encoders = text_encoders + if vae is not None: + model.vae = vae model.first_stage_model = first_stage timer.record("apply fp8") else: @@ -507,8 +653,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) - timer.record("apply dtype to VAE") + # check supported vae dtype + dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype) + if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes: + devices.dtype_vae = torch.bfloat16 + print(f"VAE dtype {dtype_vae} detected. load as is.") + else: + # use default devices.dtype_vae + model.first_stage_model.to(devices.dtype_vae) + print(f"Use VAE dtype {devices.dtype_vae}") + timer.record("apply dtype to VAE") # clean up cache if limit is reached while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: @@ -661,6 +815,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' +clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight' +clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight' +t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight' class SdModelData: @@ -793,7 +950,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ if not checkpoint_config: checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) + clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight ] if x in state_dict) timer.record("find config") @@ -804,6 +961,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ print(f"Creating model from config: {checkpoint_config}") + # get all dtypes of state_dict + state_dict_dtype = get_state_dict_dtype(state_dict) + + # check loadable unet dtype before loading + loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype) + + # check dtype of vae + dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype) + if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes: + devices.dtype_vae = torch.bfloat16 + print(f"VAE dtype {dtype_vae} detected.") + sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): @@ -828,8 +997,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ else: weight_dtype_conversion = { 'first_stage_model': None, + 'text_encoders': None, + 'vae': None, 'alphas_cumprod': None, - '': torch.float16, + '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None, } with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): @@ -856,7 +1027,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ timer.record("scripts callbacks") - with devices.autocast(), torch.no_grad(): + with devices.autocast(target_dtype=devices.dtype_inference), torch.no_grad(): sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 3c1e4a151..4251062c8 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") +config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml") def is_using_v_parameterization_for_sd2(state_dict): @@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename): if "model.diffusion_model.x_embedder.proj.weight" in sd: return config_sd3 + if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: + return config_flux1 + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: if diffusion_model_input.shape[1] == 9: return config_sdxl_inpainting diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index 2fce2777b..867f8b6e2 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion): is_sd3: bool """True if the model's architecture is SD 3""" + is_flux1: bool + """True if the model's architecture is FLUX 1""" + latent_channels: int """number of layer in latent image representation; will be 16 in SD3 and 4 in other version""" diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1242a5936..5abb75d1f 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -18,7 +18,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: is_negative_prompt = getattr(batch, 'is_negative_prompt', False) aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score - devices_args = dict(device=devices.device, dtype=devices.dtype) + devices_args = dict(device=devices.device, dtype=devices.dtype_inference) sdxl_conds = { "txt": batch, diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index c060cccb2..b312c41d8 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None): x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = 255. * np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 43687e48d..6ae038333 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + loaded = False if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) - _load_vae_dict(model, checkpoints_loaded[vae_file]) + loaded = _load_vae_dict(model, checkpoints_loaded[vae_file]) else: assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) - _load_vae_dict(model, vae_dict_1) + loaded = _load_vae_dict(model, vae_dict_1) - if cache_enabled: + if loaded and cache_enabled: # cache newly loaded vae checkpoints_loaded[vae_file] = vae_dict_1.copy() # clean up cache if limit is reached - if cache_enabled: + if loaded and cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though vae_opt = get_filename(vae_file) - if vae_opt not in vae_dict: + if loaded and vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file elif loaded_vae_file: restore_base_vae(model) + loaded = True - loaded_vae_file = vae_file + if loaded: + loaded_vae_file = vae_file model.base_vae = base_vae model.loaded_vae_file = loaded_vae_file + return loaded # don't call this from outside def _load_vae_dict(model, vae_dict_1): + conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight") + # check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3] + if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape: + print("Failed to load VAE. Size mismatched!") + return False + model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) + return True def clear_loaded_vae(): @@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file, vae_source) + loaded = load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) @@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): script_callbacks.model_loaded_callback(sd_model) - print("VAE weights loaded.") + if loaded: + print("VAE weights loaded.") return sd_model diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index c5dda7431..48ffd672d 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -44,6 +44,8 @@ def model(): model_name = "vaeapprox-sd3.pt" elif shared.sd_model.is_sdxl: model_name = "vaeapprox-sdxl.pt" + elif shared.sd_model.is_flux1: + model_name = "vaeapprox-sd3.pt" else: model_name = "model.pt" @@ -81,6 +83,18 @@ def cheap_approximation(sample): [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259], ] + elif shared.sd_model.is_flux1: + coeffs = [ + # from comfy + [-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850], + [ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549], + [ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123], + [ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306], + [-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545], + [-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088], + [ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897], + [-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680], + ] elif shared.sd_model.is_sdxl: coeffs = [ [ 0.3448, 0.4168, 0.4395], diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index d06253d2a..76771e95e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(decoder_path) else 4 + latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4 self.decoder = decoder(latent_channels) self.decoder.load_state_dict( @@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(encoder_path) else 4 + latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4 self.encoder = encoder(latent_channels) self.encoder.load_state_dict( @@ -97,6 +97,8 @@ def download_model(model_path, model_url): def decoder_model(): if shared.sd_model.is_sd3: model_name = "taesd3_decoder.pth" + elif shared.sd_model.is_flux1: + model_name = "taef1_decoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_decoder.pth" else: @@ -122,6 +124,8 @@ def decoder_model(): def encoder_model(): if shared.sd_model.is_sd3: model_name = "taesd3_encoder.pth" + elif shared.sd_model.is_flux1: + model_name = "taef1_encoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_encoder.pth" else: diff --git a/modules/shared_init.py b/modules/shared_init.py index a6ad0433d..2d58c6374 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -29,7 +29,7 @@ def initialize(): devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 - devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype + devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype_inference if cmd_opts.precision == "half": msg = "--no-half and --no-half-vae conflict with --precision half" diff --git a/modules/shared_options.py b/modules/shared_options.py index 03632ecc0..3ba02d00f 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -196,6 +196,9 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), { "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), })) +options_templates.update(options_section(('flux', "Stable Diffusion FLUX", "sd"), { + "flux_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"), +})) options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" @@ -243,6 +246,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd" "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), + "lora_without_backup_weight": OptionInfo(False, "LoRA without backup weights").info("LoRA without backup weights to save RAM."), })) options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { diff --git a/webui.py b/webui.py index 421e3b833..398d83550 100644 --- a/webui.py +++ b/webui.py @@ -6,6 +6,8 @@ import time from modules import timer from modules import initialize_util from modules import initialize +from modules import manager +from threading import Thread startup_timer = timer.startup_timer startup_timer.record("launcher") @@ -14,6 +16,8 @@ initialize.imports() initialize.check_versions() +initialize.initialize() + def create_api(app): from modules.api.api import Api @@ -23,12 +27,10 @@ def create_api(app): return api -def api_only(): +def _api_only(): from fastapi import FastAPI from modules.shared_cmd_options import cmd_opts - initialize.initialize() - app = FastAPI() initialize_util.setup_middleware(app) api = create_api(app) @@ -83,11 +85,10 @@ For more information see: https://github.com/AUTOMATIC1111/stable-diffusion-webu {"!"*25} Warning {"!"*25}''') -def webui(): +def _webui(): from modules.shared_cmd_options import cmd_opts launch_api = cmd_opts.api - initialize.initialize() from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks @@ -177,6 +178,7 @@ def webui(): print("Stopping server...") # If we catch a keyboard interrupt, we want to stop the server and exit. shared.demo.close() + manager.task.stop() break # disable auto launch webui in browser for subsequent UI Reload @@ -193,6 +195,13 @@ def webui(): initialize.initialize_rest(reload_script_modules=True) +def api_only(): + Thread(target=_api_only, daemon=True).start() + + +def webui(): + Thread(target=_webui, daemon=True).start() + if __name__ == "__main__": from modules.shared_cmd_options import cmd_opts @@ -200,3 +209,5 @@ if __name__ == "__main__": api_only() else: webui() + + manager.task.main_loop()