Merge 310d0e6938 into fd68e0c384
commit 4c937cbfab
32 changed files with 1693 additions and 131 deletions
.github/workflows/run_tests.yaml (vendored, 1 change)
@@ -51,6 +51,7 @@ jobs:
            --test-server
            --do-not-download-clip
            --no-half
            --precision full
            --disable-opt-split-attention
            --use-cpu all
            --api-server-stop
configs/flux1-inference.yaml (new file, 4 lines)
@@ -0,0 +1,4 @@
model:
  target: modules.models.flux.FLUX1Inferencer
  params:
    state_dict: null
@@ -2,6 +2,7 @@ import torch

import lyco_helpers
import modules.models.sd3.mmdit
+import modules.models.flux.modules.layers
import network
from modules import devices


@@ -37,7 +38,7 @@ class NetworkModuleLora(network.NetworkModule):
        if weight is None and none_ok:
            return None

-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
@@ -37,7 +37,7 @@ module_types = [


re_digits = re.compile(r"\d+")
-re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
re_compiled = {}

suffix_conversion = {
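A quick sketch (illustrative keys, not part of the diff) of what the widened pattern matches: Flux single blocks fuse q/k/v and the MLP input into one linear layer, so split-out LoRA module names can now also end in "_mlp_proj".

import re

re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")

for key in ("double_blocks_0_img_attn_qkv_q_proj", "single_blocks_0_linear1_mlp_proj"):
    m = re_x_proj.match(key)
    print(m.groups())
# ('double_blocks_0_img_attn_qkv', 'q_proj')
# ('single_blocks_0_linear1', 'mlp_proj')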
@@ -183,8 +183,12 @@ def load_network(name, network_on_disk):
    for key_network, weight in sd.items():

        if diffusers_weight_map:
-            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
-            network_part = network_name + '.' + network_weight
+            if key_network.startswith("lora_unet"):
+                key_network_without_network_parts, _, network_part = key_network.partition(".")
+                key_network_without_network_parts = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+            else:
+                key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+                network_part = network_name + '.' + network_weight
        else:
            key_network_without_network_parts, _, network_part = key_network.partition(".")
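A small illustration (hypothetical kohya-style key, not from the diff) of the new `lora_unet` branch: the prefix is rewritten to `diffusion_model` so the key lines up with the names produced by `diffusers_weight_mapping()`.

key_network = "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"   # hypothetical key

key_network_without_network_parts, _, network_part = key_network.partition(".")
key_network_without_network_parts = key_network_without_network_parts.replace("lora_unet", "diffusion_model")

print(key_network_without_network_parts)  # diffusion_model_double_blocks_0_img_attn_qkv
print(network_part)                       # lora_up.weight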
@@ -373,11 +377,13 @@ def allowed_layer_without_weight(layer):
    return False


-def store_weights_backup(weight):
+def store_weights_backup(weight, dtype):
    if weight is None:
        return None

-    return weight.to(devices.cpu, copy=True)
+    if shared.opts.lora_without_backup_weight:
+        return True
+    return weight.to(devices.cpu, dtype=dtype, copy=True)


def restore_weights_backup(obj, field, weight):
@@ -385,16 +391,20 @@ def restore_weights_backup(obj, field, weight):
        setattr(obj, field, None)
        return

-    getattr(obj, field).copy_(weight)
+    old_weight = getattr(obj, field)
+    old_weight.copy_(weight.to(dtype=old_weight.dtype))


-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

+    if shared.opts.lora_without_backup_weight:
+        return
+
    if weights_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
@@ -407,6 +417,51 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
        else:
            restore_weights_backup(self, 'bias', bias_backup)

    if cleanup:
        if weights_backup is not None:
            del self.network_weights_backup
        if bias_backup is not None:
            del self.network_bias_backup


def network_backup_weights(self):
    network_layer_name = getattr(self, 'network_layer_name', None)

    _current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    need_backup = False
    for net in loaded_networks:
        if network_layer_name in net.modules:
            need_backup = True
            break
        elif network_layer_name + "_q_proj" in net.modules:
            need_backup = True
            break

    if not need_backup:
        return

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
        else:
            weights_backup = store_weights_backup(self.weight, self.org_dtype)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and wanted_names != ():
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = store_weights_backup(self.bias, self.org_dtype)
        else:
            bias_backup = None

        self.network_bias_backup = bias_backup


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
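For orientation, a minimal standalone sketch of the backup/restore contract introduced above (ignoring the MultiheadAttention case and the `lora_without_backup_weight` fast path): weights are copied to CPU at the layer's original dtype, and restored with a dtype-matching in-place copy.

import torch

def store_weights_backup(weight, dtype):
    # sketch: keep a CPU copy at the layer's original dtype
    if weight is None:
        return None
    return weight.to("cpu", dtype=dtype, copy=True)

def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return
    old_weight = getattr(obj, field)
    old_weight.copy_(weight.to(dtype=old_weight.dtype))

layer = torch.nn.Linear(4, 4).to(torch.float16).requires_grad_(False)
backup = store_weights_backup(layer.weight, torch.float16)
layer.weight += 1.0                       # stand-in for applying a LoRA delta in place
restore_weights_backup(layer, "weight", backup)
assert torch.equal(layer.weight, backup)  # original weights are back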
@@ -424,38 +479,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
-        if current_names != () and not allowed_layer_without_weight(self):
-            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
-
-        if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
-        else:
-            weights_backup = store_weights_backup(self.weight)
-
-        self.network_weights_backup = weights_backup
-
-    bias_backup = getattr(self, "network_bias_backup", None)
-    if bias_backup is None and wanted_names != ():
-        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = store_weights_backup(self.out_proj.bias)
-        elif getattr(self, 'bias', None) is not None:
-            bias_backup = store_weights_backup(self.bias)
-        else:
-            bias_backup = None
-
-        # Unlike weight which always has value, some modules don't have bias.
-        # Only report if bias is not None and current bias are not unchanged.
-        if bias_backup is not None and current_names != ():
-            raise RuntimeError("no backup bias found and current bias are not unchanged")
-
-        self.network_bias_backup = bias_backup
+        network_backup_weights(self)
+    elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False):
+        network_restore_weights_from_backup(self)

    if current_names != wanted_names:
-        network_restore_weights_from_backup(self)
+        if hasattr(self, "weights_restored"):
+            self.weights_restored = False

    for net in loaded_networks:
        module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
+        if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
            try:
                with torch.no_grad():
                    if getattr(self, 'fp16_weight', None) is None:
@@ -478,6 +512,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                        self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                    else:
                        self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                del weight, bias, updown, ex_bias
            except RuntimeError as e:
                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -515,7 +550,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
            continue

-        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+        module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
+
+        if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]:
            try:
                with torch.no_grad():
                    # Send "real" orig_weight into MHA's lora module
@@ -526,6 +563,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                    del qw, kw, vw
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                    self.weight += updown_qkv
                    del updown_qkv
                    del updown_q, updown_k, updown_v

            except RuntimeError as e:
                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

            continue

        if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v:
            try:
                with torch.no_grad():
                    qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
                    updown_q, _ = module_q.calc_updown(qw)
                    updown_k, _ = module_k.calc_updown(kw)
                    updown_v, _ = module_v.calc_updown(vw)
                    if module_mlp is not None:
                        updown_mlp, _ = module_mlp.calc_updown(mlp)
                    else:
                        updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device)
                    del qw, kw, vw, mlp
                    updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
                    self.weight += updown_qkv_mlp
                    del updown_qkv_mlp
                    del updown_q, updown_k, updown_v, updown_mlp

            except RuntimeError as e:
                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
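For context, a small shape-level sketch (dummy weight, not part of the diff) of how `torch.tensor_split` carves the fused single-block `linear1` weight of Flux: the hidden size is 3072, so rows 0-3071 are q, 3072-6143 are k, 6144-9215 are v, and the remaining 4*3072 rows feed the MLP branch.

import torch

hidden = 3072
# linear1 projects to 3*hidden (q, k, v) plus 4*hidden (MLP input)
fused_weight = torch.zeros(hidden * 3 + hidden * 4, hidden)

qw, kw, vw, mlp = torch.tensor_split(fused_weight, (3072, 6144, 9216), dim=0)
print(qw.shape, kw.shape, vw.shape, mlp.shape)
# torch.Size([3072, 3072]) for q/k/v and torch.Size([12288, 3072]) for the MLP slice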
@@ -539,7 +601,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

-    self.network_current_names = wanted_names
+    if shared.opts.lora_without_backup_weight:
+        self.network_weights_backup = None
+        self.network_bias_backup = None
+    else:
+        self.network_current_names = wanted_names


def network_forward(org_module, input, original_forward):
@@ -1,15 +1,17 @@
import re
+import torch

import gradio as gr
from fastapi import FastAPI

+import gc
import network
import networks
import lora # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices


def unload():
@@ -97,6 +99,82 @@ def infotext_pasted(infotext, d):
    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])


class ScriptLora(scripts.Script):
    name = "Lora"

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def after_extra_networks_activate(self, p, *args, **kwargs):
        # check modules and setup org_dtype
        modules = []
        if shared.sd_model.is_sdxl:
            for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
                if not hasattr(embedder, 'wrapped'):
                    continue

                for _name, module in embedder.wrapped.named_modules():
                    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
                        if hasattr(module, 'weight'):
                            modules.append(module)
                        elif isinstance(module, torch.nn.MultiheadAttention):
                            modules.append(module)

        else:
            cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)

            for _name, module in cond_stage_model.named_modules():
                if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
                    if hasattr(module, 'weight'):
                        modules.append(module)
                    elif isinstance(module, torch.nn.MultiheadAttention):
                        modules.append(module)

        for _name, module in shared.sd_model.model.named_modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
                if hasattr(module, 'weight'):
                    modules.append(module)
                elif isinstance(module, torch.nn.MultiheadAttention):
                    modules.append(module)

        print("Total lora modules after_extra_networks_activate() =", len(modules))

        target_dtype = devices.dtype_inference
        for module in modules:
            network_layer_name = getattr(module, 'network_layer_name', None)
            if network_layer_name is None:
                continue

            if isinstance(module, torch.nn.MultiheadAttention):
                org_dtype = torch.float32
            else:
                org_dtype = None
                for _name, param in module.named_parameters():
                    if param.dtype != target_dtype:
                        org_dtype = param.dtype
                        break

            # set org_dtype
            module.org_dtype = org_dtype

            # backup/restore weights
            current_names = getattr(module, "network_current_names", ())
            wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks)

            weights_backup = getattr(module, "network_weights_backup", None)

            if current_names == () and current_names != wanted_names and weights_backup is None:
                networks.network_backup_weights(module)
            elif current_names != () and current_names != wanted_names:
                networks.network_restore_weights_from_backup(module, wanted_names == ())
                module.weights_restored = True

        if current_names != wanted_names and wanted_names == ():
            gc.collect()


script_callbacks.on_infotext_pasted(infotext_pasted)

shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
@@ -3,7 +3,7 @@ from functools import wraps
import html
import time

-from modules import shared, progress, errors, devices, fifo_lock, profiling
+from modules import shared, progress, errors, devices, fifo_lock, profiling, manager

queue_lock = fifo_lock.FIFOLock()


@@ -34,7 +34,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
        progress.start_task(id_task)

        try:
-            res = func(*args, **kwargs)
+            res = manager.task.run_and_wait_result(func, *args, **kwargs)
            progress.record_results(id_task, res)
        finally:
            progress.finish_task(id_task)
@@ -1,5 +1,6 @@
import sys
import contextlib
+from copy import deepcopy
from functools import lru_cache

import torch
@@ -128,6 +129,26 @@ dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False

supported_vae_dtypes = [torch.float16, torch.float32]


# prepare available dtypes
if torch.version.cuda:
    if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
        supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
if has_xpu():
    supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes


def supports_non_blocking():
    if has_mps() or has_xpu():
        return False

    if npu_specific.has_npu:
        return False

    return True


def cond_cast_unet(input):
    if force_fp16:
@@ -146,17 +167,33 @@ patch_module_list = [
    torch.nn.MultiheadAttention,
    torch.nn.GroupNorm,
    torch.nn.LayerNorm,
    torch.nn.Embedding,
]


-def manual_cast_forward(target_dtype):
+def manual_cast_forward(target_dtype, target_device=None, copy=False):
+    params = dict()
+    if supports_non_blocking():
+        params['non_blocking'] = True
+
+    supported_cast_dtypes = [torch.float16, torch.float32]
+    if torch.cuda.is_bf16_supported():
+        supported_cast_dtypes += [torch.bfloat16]
+
    def forward_wrapper(self, *args, **kwargs):
-        if any(
-            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
-            for arg in args
-        ):
-            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+        if target_device is not None:
+            params['device'] = target_device
+        params['dtype'] = target_dtype
+
+        args = list(args)
+        for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
+            if args[j].dtype in supported_cast_dtypes:
+                args[j] = args[j].to(**params)
+        args = tuple(args)
+
+        for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
+            if kwargs[key].dtype in supported_cast_dtypes:
+                kwargs[key] = kwargs[key].to(**params)

        org_dtype = target_dtype
        for param in self.parameters():
@@ -164,38 +201,52 @@ def manual_cast_forward(target_dtype):
                org_dtype = param.dtype
                break

-        if org_dtype != target_dtype:
-            self.to(target_dtype)
-        result = self.org_forward(*args, **kwargs)
-        if org_dtype != target_dtype:
-            self.to(org_dtype)
+        if copy and not isinstance(self, torch.nn.Embedding):
+            copied = deepcopy(self)
+            if org_dtype != target_dtype:
+                copied.to(**params)
+
+            result = copied.org_forward(*args, **kwargs)
+            del copied
+        else:
+            if org_dtype != target_dtype:
+                self.to(**params)
+
+            result = self.org_forward(*args, **kwargs)
+
+            if org_dtype != target_dtype:
+                params['dtype'] = org_dtype
+                self.to(**params)

        if target_dtype != dtype_inference:
+            params['dtype'] = dtype_inference
            if isinstance(result, tuple):
                result = tuple(
-                    i.to(dtype_inference)
+                    i.to(**params)
                    if isinstance(i, torch.Tensor)
                    else i
                    for i in result
                )
            elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
+                result = result.to(**params)
        return result
    return forward_wrapper


@contextlib.contextmanager
-def manual_cast(target_dtype):
+def manual_cast(target_dtype, target_device=None, copy=None):
    applied = False

    for module_type in patch_module_list:
        if hasattr(module_type, "org_forward"):
            continue
        applied = True
        org_forward = module_type.forward
        if module_type == torch.nn.MultiheadAttention:
-            module_type.forward = manual_cast_forward(torch.float32)
+            module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
        else:
-            module_type.forward = manual_cast_forward(target_dtype)
+            module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
        module_type.org_forward = org_forward
    try:
        yield None
@@ -207,26 +258,37 @@ def manual_cast(target_dtype):
                    delattr(module_type, "org_forward")


-def autocast(disable=False):
+def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None, copy=None):
    if disable:
        return contextlib.nullcontext()

+    copy = copy if copy is not None else shared.opts.lora_without_backup_weight
+
+    if target_dtype is None:
+        target_dtype = dtype
+
    if force_fp16:
        # No casting during inference if force_fp16 is enabled.
        # All tensor dtype conversion happens before inference.
        return contextlib.nullcontext()

-    if fp8 and device==cpu:
+    if fp8 and target_device==cpu:
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if fp8 and dtype_inference == torch.float32:
-        return manual_cast(dtype)
+        return manual_cast(target_dtype, target_device, copy=copy)

-    if dtype == torch.float32 or dtype_inference == torch.float32:
+    if target_dtype != dtype_inference or copy:
+        return manual_cast(target_dtype, target_device, copy=copy)
+
+    if current_dtype is not None and current_dtype != target_dtype:
+        return manual_cast(target_dtype, target_device, copy=copy)
+
+    if target_dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()

    if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype)
+        return manual_cast(target_dtype, target_device)

    return torch.autocast("cuda")
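A hedged usage sketch of the extended signature above (the surrounding call sites are not shown in this diff; this assumes a running webui where `shared.opts` is initialized): callers can now ask for a specific target dtype/device and, with `copy=True`, have the cast applied to a deepcopy of each patched module so the module's own weights are left untouched.

import torch
from modules import devices

x = torch.randn(1, 4, 64, 64, dtype=torch.float16)

with devices.autocast(current_dtype=x.dtype, target_dtype=torch.bfloat16, copy=True):
    ...  # run model code here; Linear/Conv/Norm forward() inputs are cast on the fly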
@@ -495,11 +495,17 @@ def configure_for_tests():
def start():
    print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
    import webui

+    from modules import manager
+
    if '--nowebui' in sys.argv:
        webui.api_only()
    else:
        webui.webui()

+    manager.task.main_loop()
+    return
+

def dump_sysinfo():
    from modules import sysinfo
@@ -53,6 +53,7 @@ def setup_for_low_vram(sd_model, use_medvram):

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)
+            devices.torch_gc()

        module.to(devices.device)
        module_in_gpu = module
modules/manager.py (new file, 83 lines)

@@ -0,0 +1,83 @@
#
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
#
# Original author comment:
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not influenced.
# By using one single thread to process all major calls, model moving is significantly faster.
#
# 2024/09/28 classified,

import random
import string
import threading
import time

from collections import OrderedDict


class Task:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class TaskManager:
    last_exception = None
    pending_tasks = []
    finished_tasks = OrderedDict()
    lock = None
    running = False

    def __init__(self):
        self.lock = threading.Lock()

    def work(self, task):
        try:
            task.result = task.func(*task.args, **task.kwargs)
        except Exception as e:
            task.exception = e
            self.last_exception = e

    def stop(self):
        self.running = False

    def main_loop(self):
        self.running = True
        while self.running:
            time.sleep(0.01)
            if len(self.pending_tasks) > 0:
                with self.lock:
                    task = self.pending_tasks.pop(0)

                self.work(task)

                self.finished_tasks[task.task_id] = task

    def push_task(self, func, *args, **kwargs):
        if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
            task_id = args[0]
        else:
            task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
        task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
        self.pending_tasks.append(task)

        return task.task_id

    def run_and_wait_result(self, func, *args, **kwargs):
        current_id = self.push_task(func, *args, **kwargs)

        while True:
            time.sleep(0.01)
            if current_id in self.finished_tasks:
                finished = self.finished_tasks.pop(current_id)
                if finished.exception is not None:
                    raise finished.exception

                return finished.result


task = TaskManager()
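A minimal usage sketch of the manager above (standalone, with a stand-in function; in the webui itself `wrap_gradio_gpu_call` does the pushing and `launch_utils.start()` runs the loop): request threads call `run_and_wait_result`, while the process's main thread sits in `main_loop()` and executes the queued work.

import threading
import time

from modules import manager  # the TaskManager singleton defined above


def render(prompt):
    # stand-in for an expensive txt2img call
    time.sleep(0.1)
    return f"image for {prompt!r}"


def request_handler():
    # runs on a worker thread; blocks until the main thread has executed render()
    result = manager.task.run_and_wait_result(render, "a cat")
    print(result)
    manager.task.stop()


threading.Thread(target=request_handler).start()
manager.task.main_loop()   # main thread executes queued tasks until stop() is called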
modules/models/flux/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
from .flux import FLUX1Inferencer

__all__ = [
    "FLUX1Inferencer",
]
modules/models/flux/flux.py (new file, 360 lines)
|
|
@ -0,0 +1,360 @@
|
|||
import contextlib
|
||||
|
||||
import os
|
||||
import safetensors
|
||||
import torch
|
||||
import math
|
||||
|
||||
import k_diffusion
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from modules import shared, devices, modelloader, sd_hijack_clip
|
||||
|
||||
from modules.models.sd3.sd3_impls import SDVAE
|
||||
from modules.models.sd3.sd3_cond import CLIPL_CONFIG, T5_CONFIG, CLIPL_URL, T5_URL, SafetensorsMapping, Sd3T5
|
||||
from modules.models.sd3.other_impls import SDClipModel, T5XXLModel, SDTokenizer, T5XXLTokenizer
|
||||
from PIL import Image
|
||||
|
||||
from .model import Flux
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self):
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.t5xxl = T5XXLTokenizer()
|
||||
|
||||
def tokenize_with_weights(self, text:str):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
|
||||
return out
|
||||
|
||||
|
||||
class Flux1ClipL(sd_hijack_clip.TextConditionalModel):
|
||||
def __init__(self, clip_l):
|
||||
super().__init__()
|
||||
|
||||
self.clip_l = clip_l
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.id_start = empty[0]
|
||||
self.id_end = empty[1]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
self.return_pooled = True
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
l_out, l_pooled = self.clip_l(tokens)
|
||||
l_out = torch.cat([l_out], dim=-1)
|
||||
l_out = torch.nn.functional.pad(l_out, (0, 4096 - l_out.shape[-1]))
|
||||
|
||||
vector_out = torch.cat([l_pooled], dim=-1)
|
||||
|
||||
l_out.pooled = vector_out
|
||||
|
||||
return l_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
|
||||
|
||||
|
||||
|
||||
class FluxCond(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.tokenizer = FluxTokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
|
||||
if shared.opts.flux_enable_t5:
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
self.model_l = Flux1ClipL(self.clip_l)
|
||||
self.model_t5 = Sd3T5(self.t5xxl)
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
with devices.without_autocast():
|
||||
l_out, vector_out = self.model_l(prompts)
|
||||
t5_out = self.model_t5(prompts, token_count=l_out.shape[1])
|
||||
lt_out = torch.cat([l_out, t5_out], dim=-2)
|
||||
|
||||
return {
|
||||
'crossattn': lt_out,
|
||||
'vector': vector_out,
|
||||
}
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return self.model_l.encode_embedding_init_text(init_text, nvpt)
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.model_l.tokenize(texts)
|
||||
|
||||
def medvram_modules(self):
|
||||
return [self.clip_l, self.t5xxl]
|
||||
|
||||
def get_token_count(self, text):
|
||||
_, token_count = self.model_l.process_texts([text])
|
||||
|
||||
return token_count
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
return self.model_l.get_target_prompt_token_count(token_count)
|
||||
|
||||
def flux_time_shift(mu: float, sigma: float, t):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
class ModelSamplingFlux(torch.nn.Module):
|
||||
def __init__(self, shift=1.15):
|
||||
super().__init__()
|
||||
|
||||
self.set_parameters(shift=shift)
|
||||
|
||||
def set_parameters(self, shift=1.15, timesteps=10000):
|
||||
self.shift = shift
|
||||
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
|
||||
self.register_buffer('sigmas', ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma
|
||||
|
||||
def sigma(self, timestep):
|
||||
return flux_time_shift(self.shift, 1.0, timestep)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core FLUX model"""
|
||||
def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs)
|
||||
self.model_sampling = ModelSamplingFlux(shift=shift)
|
||||
self.depth = kwargs['depth']
|
||||
self.depth_single_block = kwargs['depth_single_blocks']
|
||||
|
||||
def apply_model(self, x, sigma, c_crossattn=None, y=None):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32)
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype)
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
|
||||
class FLUX1LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
def __init__(self, scale_factor=0.3611, shift_factor=0.1159):
|
||||
self.scale_factor = scale_factor
|
||||
self.shift_factor = shift_factor
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
"""Quick RGB approximate preview of sd3 latents"""
|
||||
factors = torch.tensor([
|
||||
[-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850],
|
||||
[ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549],
|
||||
[ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123],
|
||||
[ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088],
|
||||
[ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680],
|
||||
], device="cpu")
|
||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||
|
||||
latents_ubyte = (((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
|
||||
class FLUX1Denoiser(k_diffusion.external.DiscreteSchedule):
|
||||
def __init__(self, inner_model, sigmas):
|
||||
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
|
||||
self.inner_model = inner_model
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
return self.inner_model.apply_model(input, sigma, **kwargs)
|
||||
|
||||
|
||||
class FLUX1Inferencer(torch.nn.Module):
|
||||
def __init__(self, state_dict, use_ema=False):
|
||||
super().__init__()
|
||||
|
||||
params = dict(
|
||||
image_model="flux",
|
||||
in_channels=16,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
)
|
||||
|
||||
# detect model_prefix
|
||||
diffusion_model_prefix = "model.diffusion_model."
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
diffusion_model_prefix = "model.diffusion_model."
|
||||
elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
diffusion_model_prefix = ""
|
||||
|
||||
shift=1.15
|
||||
# check guidance_in to detect Flux schnell
|
||||
if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict:
|
||||
print("Flux schnell detected")
|
||||
params.update(dict(guidance_embed=False,))
|
||||
shift=1.0
|
||||
|
||||
with torch.no_grad():
|
||||
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params)
|
||||
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
|
||||
self.first_stage_model.dtype = devices.dtype_vae
|
||||
self.vae = self.first_stage_model # real vae
|
||||
|
||||
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
|
||||
|
||||
self.text_encoders = FluxCond()
|
||||
self.cond_stage_key = 'txt'
|
||||
|
||||
self.parameterization = "eps"
|
||||
self.model.conditioning_key = "crossattn"
|
||||
|
||||
self.latent_format = FLUX1LatentFormat()
|
||||
self.latent_channels = 16
|
||||
|
||||
@property
|
||||
def cond_stage_model(self):
|
||||
return self.text_encoders
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
self.cond_stage_model.before_load_weights(state_dict)
|
||||
|
||||
def ema_scope(self):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def get_learned_conditioning(self, batch: list[str]):
|
||||
return self.cond_stage_model(batch)
|
||||
|
||||
def apply_model(self, x, t, cond):
|
||||
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
||||
|
||||
def decode_first_stage(self, latent):
|
||||
latent = self.latent_format.process_out(latent)
|
||||
x = self.first_stage_model.decode(latent)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
def encode_first_stage(self, image):
|
||||
latent = self.first_stage_model.encode(image)
|
||||
return self.latent_format.process_in(latent)
|
||||
|
||||
def get_first_stage_encoding(self, x):
|
||||
return x
|
||||
|
||||
def create_denoiser(self):
|
||||
return FLUX1Denoiser(self, self.model.model_sampling.sigmas)
|
||||
|
||||
def medvram_fields(self):
|
||||
return [
|
||||
(self, 'first_stage_model'),
|
||||
(self, 'text_encoders'),
|
||||
(self, 'model'),
|
||||
]
|
||||
|
||||
def add_noise_to_latent(self, x, noise, amount):
|
||||
return x * (1 - amount) + noise * amount
|
||||
|
||||
def fix_dimensions(self, width, height):
|
||||
return width // 16 * 16, height // 16 * 16
|
||||
|
||||
def diffusers_weight_mapping(self):
|
||||
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
# please see also https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
|
||||
for i in range(self.model.depth):
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_k_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_q_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_v_proj"
|
||||
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_add_out", f"diffusion_model_double_blocks_{i}_txt_attn_proj"
|
||||
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_double_blocks_{i}_img_attn_qkv_k_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_double_blocks_{i}_img_attn_qkv_q_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_double_blocks_{i}_img_attn_qkv_v_proj"
|
||||
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_double_blocks_{i}_img_attn_proj"
|
||||
|
||||
yield f"transformer.transformer_blocks.{i}.ff.net.0.proj", f"diffusion_model_double_blocks_{i}_img_mlp_0"
|
||||
yield f"transformer.transformer_blocks.{i}.ff.net.2", f"diffusion_model_double_blocks_{i}_img_mlp_2"
|
||||
yield f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", f"diffusion_model_double_blocks_{i}_txt_mlp_0"
|
||||
yield f"transformer.transformer_blocks.{i}.ff_context.net.2", f"diffusion_model_double_blocks_{i}_txt_mlp_2"
|
||||
yield f"transformer.transformer_blocks.{i}.norm1.linear", f"diffusion_model_double_blocks_{i}_img_mod_lin"
|
||||
yield f"transformer.transformer_blocks.{i}.norm1_context.linear", f"diffusion_model_double_blocks_{i}_txt_mod_lin"
|
||||
|
||||
for i in range(self.model.depth_single_block):
|
||||
yield f"transformer.single_transformer_blocks.{i}.attn.to_q", f"diffusion_model_single_blocks_{i}_linear1_q_proj"
|
||||
yield f"transformer.single_transformer_blocks.{i}.attn.to_k", f"diffusion_model_single_blocks_{i}_linear1_k_proj"
|
||||
yield f"transformer.single_transformer_blocks.{i}.attn.to_v", f"diffusion_model_single_blocks_{i}_linear1_v_proj"
|
||||
yield f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj"
|
||||
|
||||
yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2"
|
||||
yield f"transformer.single_transformer_blocks.{i}.norm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin"
|
||||
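As a quick numeric check of the shifted flow-matching schedule defined by `flux_time_shift` above (a standalone sketch; `shift=1.15` is the dev-model default set in `ModelSamplingFlux`, and the file switches to 1.0 when a schnell checkpoint is detected):

import math

def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# sigma(t) for the dev shift of 1.15 versus schnell's 1.0
for t in (0.25, 0.5, 0.75, 1.0):
    print(t, round(flux_time_shift(1.15, 1.0, t), 4), round(flux_time_shift(1.0, 1.0, t), 4))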
modules/models/flux/math.py (new file, 30 lines)

@@ -0,0 +1,30 @@
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
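A shape-level sanity sketch of the rotary-embedding helpers above (arbitrary small sizes chosen for illustration; Flux itself uses `axes_dim=[16, 56, 56]` per head as configured in `flux.py`):

import torch

# one positional axis, 8 tokens, head dim 16, batch 1, 2 heads
pos = torch.arange(8, dtype=torch.float32)[None, :]        # (b=1, n=8)
pe = rope(pos, dim=16, theta=10000)                        # (1, 8, 8, 2, 2)
pe = pe.unsqueeze(1)                                       # broadcast over heads, as EmbedND does

q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
q_rot, k_rot = apply_rope(q, k, pe)
print(q_rot.shape, k_rot.shape)                            # torch.Size([1, 2, 8, 16]) twice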
modules/models/flux/model.py (new file, 147 lines)
|
|
@ -0,0 +1,147 @@
|
|||
# original code from https://github.com/black-forest-labs/flux
|
||||
#
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
dtype=dtype, device=device,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
if self.final_layer:
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
# from comfy/ldm/common_dit.py
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
padding_mode = "reflect"
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
modules/models/flux/modules/layers.py (new file, 265 lines)
|
|
@ -0,0 +1,265 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..math import attention, rope
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class QkvLinear(torch.nn.Linear):
|
||||
pass
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device)
|
||||
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = QkvLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x

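All of the blocks above share the same adaLN-style conditioning: a modulation vector is projected into shift/scale/gate terms, the normalized activations are rescaled as (1 + scale) * norm(x) + shift, and the sub-layer output is added back through a gated residual. The following standalone sketch (illustrative names only, not part of the patch) shows that pattern in isolation:

import torch
import torch.nn as nn

class ToyModulation(nn.Module):
    # illustrative only: produces one (shift, scale, gate) triple from a conditioning vector
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, 3 * dim)

    def forward(self, vec):
        shift, scale, gate = self.lin(nn.functional.silu(vec))[:, None, :].chunk(3, dim=-1)
        return shift, scale, gate

dim = 8
mod = ToyModulation(dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)
sublayer = nn.Linear(dim, dim)  # stands in for attention or an MLP

x = torch.randn(2, 16, dim)     # (batch, tokens, dim)
vec = torch.randn(2, dim)       # conditioning vector (timestep / guidance / pooled text)

shift, scale, gate = mod(vec)
x = x + gate * sublayer((1 + scale) * norm(x) + shift)  # gated residual update
print(x.shape)  # torch.Size([2, 16, 8])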
201
modules/models/flux/util.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
import os
from dataclasses import dataclass

import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae

class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
        # watermarking library expects input as cv2 BGR format
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)

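Taken together, these helpers give a reference path for standing up a FLUX pipeline outside the webui. The following is a minimal sketch of how they might be called; it assumes the FLUX_DEV/AE environment variables point at local weights or that Hugging Face downloads are permitted, and that a CUDA device (or enough CPU memory) is available.

import torch

# hypothetical usage of the loaders defined above (flux-dev variant)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = load_flow_model("flux-dev", device=device)   # Flux transformer, cast to bfloat16
ae = load_ae("flux-dev", device=device)              # autoencoder for 16-channel latents
t5 = load_t5(device=device, max_length=512)          # T5-XXL text encoder
clip = load_clip(device=device)                      # CLIP-L pooled text encoder

print(type(model).__name__, type(ae).__name__)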
|
|
@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear):
|
|||
    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)


class AutocastLayerNorm(nn.LayerNorm):
    def forward(self, x):
        return torch.nn.functional.layer_norm(
            x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps)


def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
|
|
@ -41,9 +46,9 @@ class Mlp(nn.Module):
|
|||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.act = act_layer
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
|
|
@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module):
|
|||
def __init__(self, embed_dim, heads, dtype, device):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
q = self.q_proj(x)
|
||||
|
|
@ -82,9 +87,11 @@ ACTIVATIONS = {
|
|||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
|
||||
super().__init__()
|
||||
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
#self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
#self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
#self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
|
||||
|
||||
|
|
@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module):
|
|||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
x = self.embeddings(input_tokens)
|
||||
|
|
@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module):
|
|||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection = AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
|
|
@ -370,7 +377,7 @@ class T5Attention(torch.nn.Module):
|
|||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
|
||||
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
|
|
@ -442,7 +449,7 @@ class T5Attention(torch.nn.Module):
|
|||
else:
|
||||
mask = None
|
||||
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(q.dtype) if mask is not None else None)
|
||||
|
||||
return self.o(out), past_bias
|
||||
|
||||
|
|
@ -475,19 +482,21 @@ class T5Block(torch.nn.Module):
|
|||
class T5Stack(torch.nn.Module):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||
#self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device, dtype=torch.float32)
|
||||
self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
def forward(self, x, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes
|
||||
#x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes
|
||||
# some T5XXL checkpoints do not ship encoder.embed_tokens; use the shared embedding instead (as ComfyUI does)
|
||||
past_bias = None
|
||||
for i, layer in enumerate(self.block):
|
||||
x, past_bias = layer(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
x = torch.nan_to_num(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
|
@ -498,13 +507,18 @@ class T5(torch.nn.Module):
|
|||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
|
||||
self.shared = torch.nn.Embedding(config_dict["vocab_size"], config_dict["d_model"], device=device, dtype=torch.float32)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.embed_tokens
|
||||
#return self.encoder.embed_tokens
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.encoder.embed_tokens = embeddings
|
||||
#self.encoder.embed_tokens = embeddings
|
||||
self.shared = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.encoder(*args, **kwargs)
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
x = self.shared(input_ids).float()
|
||||
x = torch.nan_to_num(x)
|
||||
return self.encoder(x, *args, **kwargs)
|
||||
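The net effect of this change is that token ids are embedded once through the shared float32 table before being handed to the encoder stack, rather than relying on a per-stack embed_tokens that some checkpoints omit. A rough standalone sketch of that flow (illustrative names, not the webui classes):

import torch
import torch.nn as nn

class ToyT5(nn.Module):
    # illustrative: shared embedding kept in float32, encoder runs on the embedded sequence
    def __init__(self, vocab_size=32, d_model=16):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model, dtype=torch.float32)
        self.encoder = nn.Linear(d_model, d_model)  # stands in for the T5 encoder stack

    def forward(self, input_ids):
        x = self.shared(input_ids).float()
        x = torch.nan_to_num(x)  # guard against NaNs from low-precision checkpoints
        return self.encoder(x)

tokens = torch.randint(0, 32, (1, 8))
print(ToyT5()(tokens).shape)  # torch.Size([1, 8, 16])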
|
|
|
|||
|
|
@ -43,7 +43,7 @@ CLIPG_CONFIG = {
|
|||
"textual_inversion_key": "clip_g",
|
||||
}
|
||||
|
||||
T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
||||
T5_URL = f"{shared.hf_endpoint}/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16_e4m3fn.safetensors"
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
|
|
@ -140,7 +140,7 @@ class Sd3T5(torch.nn.Module):
|
|||
return tokens, multipliers
|
||||
|
||||
def forward(self, texts, *, token_count):
|
||||
if not self.t5xxl or not shared.opts.sd3_enable_t5:
|
||||
if not self.t5xxl:
|
||||
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
|
||||
|
||||
tokens_batch = []
|
||||
|
|
@ -164,11 +164,11 @@ class SD3Cond(torch.nn.Module):
|
|||
self.tokenizer = SD3Tokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype_inference)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
|
||||
if shared.opts.sd3_enable_t5:
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
|
|
@ -199,8 +199,8 @@ class SD3Cond(torch.nn.Module):
|
|||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -484,7 +484,7 @@ class StableDiffusionProcessing:
|
|||
|
||||
cache = caches[0]
|
||||
|
||||
with devices.autocast():
|
||||
with devices.autocast(target_dtype=devices.dtype_inference):
|
||||
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||
|
||||
cache[0] = cached_params
|
||||
|
|
@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet):
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
|
||||
if p.scripts is not None:
|
||||
|
|
@ -1147,6 +1147,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
|
||||
if lowvram.is_enabled(shared.sd_model):
|
||||
# for interrupted case
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
|
@ -1439,7 +1444,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
with devices.autocast():
|
||||
extra_networks.activate(self, self.hr_extra_network_data)
|
||||
|
||||
with devices.autocast():
|
||||
with devices.autocast(target_dtype=devices.dtype_inference):
|
||||
self.calculate_hr_conds()
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||
self.state_dict = state_dict
|
||||
self.device = device
|
||||
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||
self.default_dtype = self.weight_dtype_conversion.get('')
|
||||
self.default_dtype = self.weight_dtype_conversion.get('', None)
|
||||
|
||||
def get_weight_dtype(self, key):
|
||||
key_first_term, _ = key.split('.', 1)
|
||||
|
|
@ -176,6 +176,11 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||
used_param_keys = []
|
||||
|
||||
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.GroupNorm,)):
|
||||
# HACK add assign=True to local_metadata for some cases
|
||||
args[0]['assign_to_params_buffers'] = True
|
||||
|
||||
|
||||
for name, param in module._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
|
@ -183,12 +188,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||
key = prefix + name
|
||||
sd_param = sd.pop(key, None)
|
||||
if sd_param is not None:
|
||||
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||
dtype = self.get_weight_dtype(key)
|
||||
if dtype is None:
|
||||
state_dict[key] = sd_param
|
||||
else:
|
||||
state_dict[key] = sd_param.to(dtype=dtype)
|
||||
used_param_keys.append(key)
|
||||
|
||||
if param.is_meta:
|
||||
dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||
module._parameters[name] = torch.nn.parameter.Parameter(torch.empty_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||
|
||||
for name in module._buffers:
|
||||
key = prefix + name
|
||||
|
|
|
|||
|
|
@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
|||
if isinstance(cond, dict):
|
||||
for y in cond.keys():
|
||||
if isinstance(cond[y], list):
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
else:
|
||||
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
|
||||
with devices.autocast():
|
||||
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||
result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs)
|
||||
if devices.unet_needs_upcast:
|
||||
return result.float()
|
||||
else:
|
||||
|
|
@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
|||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
def forward(self, x):
|
||||
if devices.unet_needs_upcast:
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference)
|
||||
else:
|
||||
return torch.nn.GELU.forward(self, x)
|
||||
|
||||
|
|
@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
|||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast)
|
||||
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
|
|
@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
|||
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = devices.dtype_unet
|
||||
dtype = devices.dtype_inference
|
||||
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class ModelType(enum.Enum):
|
|||
SDXL = 3
|
||||
SSD = 4
|
||||
SD3 = 5
|
||||
FLUX1 = 6
|
||||
|
||||
|
||||
def replace_key(d, key, new_key, value):
|
||||
|
|
@ -267,6 +268,30 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
    return pl_sd


def fix_unet_prefix(state_dict):
    known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")

    for k in state_dict.keys():
        found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
        if len(found) > 0:
            return state_dict

    # no known prefix found.
    # in this case, this is a unet only state_dict
    known_keys = (
        "input_blocks.0.0.weight",  # SD1.5, SD2, SDXL
        "joint_blocks.0.context_block.adaLN_modulation.1.weight",  # SD3
        "double_blocks.0.img_attn.proj.weight",  # FLUX
    )

    if any(key in state_dict for key in known_keys):
        state_dict = {f"model.diffusion_model.{k}": v for k, v in state_dict.items()}
        print("Fixed state_dict keys...")
        return state_dict

    return state_dict


def read_metadata_from_safetensors(filename):
    import json

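The effect of fix_unet_prefix() is easiest to see on a bare diffusion-model checkpoint: keys gain the model.diffusion_model. prefix only when no known prefix is already present. A small illustrative check (tensor values are placeholders):

import torch

bare = {"double_blocks.0.img_attn.proj.weight": torch.zeros(1)}            # UNet/DiT-only checkpoint
full = {"model.diffusion_model.input_blocks.0.0.weight": torch.zeros(1)}   # already prefixed

assert "model.diffusion_model.double_blocks.0.img_attn.proj.weight" in fix_unet_prefix(bare)
assert fix_unet_prefix(full) is full  # untouched when a known prefix exists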
|
||||
|
|
@ -328,6 +353,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
res = fix_unet_prefix(res)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
return res
|
||||
|
|
@ -355,7 +381,7 @@ def check_fp8(model):
|
|||
enable_fp8 = False
|
||||
elif shared.opts.fp8_storage == "Enable":
|
||||
enable_fp8 = True
|
||||
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
|
||||
elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL":
|
||||
enable_fp8 = True
|
||||
else:
|
||||
enable_fp8 = False
|
||||
|
|
@ -368,10 +394,14 @@ def set_model_type(model, state_dict):
|
|||
model.is_sdxl = False
|
||||
model.is_ssd = False
|
||||
model.is_sd3 = False
|
||||
model.is_flux1 = False
|
||||
|
||||
if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
|
||||
model.is_sd3 = True
|
||||
model.model_type = ModelType.SD3
|
||||
elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
model.is_flux1 = True
|
||||
model.model_type = ModelType.FLUX1
|
||||
elif hasattr(model, 'conditioner'):
|
||||
model.is_sdxl = True
|
||||
|
||||
|
|
@ -393,6 +423,82 @@ def set_model_fields(model):
|
|||
model.latent_channels = 4
|
||||
|
||||
|
||||
def get_state_dict_dtype(state_dict):
    # detect dtypes of state_dict
    state_dict_dtype = {}

    known_prefixes = ("model.diffusion_model.", "first_stage_model.", "cond_stage_model.", "conditioner", "vae.", "text_encoders.")

    for k in state_dict.keys():
        found = [prefix for prefix in known_prefixes if k.startswith(prefix)]
        if len(found) > 0:
            prefix = found[0]
            dtype = state_dict[k].dtype
            dtypes = state_dict_dtype.get(prefix, {})
            if dtype in dtypes:
                dtypes[dtype] += 1
            else:
                dtypes[dtype] = 1
            state_dict_dtype[prefix] = dtypes

    for prefix in state_dict_dtype:
        dtypes = state_dict_dtype[prefix]
        # sort by count
        state_dict_dtype[prefix] = dict(sorted(dtypes.items(), key=lambda item: item[1], reverse=True))

    print("Detected dtypes:", state_dict_dtype)
    return state_dict_dtype


def get_loadable_dtype(prefix="model.diffusion_model.", state_dict=None, state_dict_dtype=None):
    if state_dict is not None:
        state_dict_dtype = get_state_dict_dtype(state_dict)

    # get the first (most common) dtype
    if prefix in state_dict_dtype:
        return list(state_dict_dtype[prefix])[0]
    return None


def get_vae_dtype(state_dict=None, state_dict_dtype=None):
    if state_dict is not None:
        state_dict_dtype = get_state_dict_dtype(state_dict)

    if state_dict_dtype is None:
        raise ValueError("failed to get VAE dtype")

    vae_prefixes = [prefix for prefix in ("vae.", "first_stage_model.") if prefix in state_dict_dtype]

    if len(vae_prefixes) > 0:
        vae_prefix = vae_prefixes[0]
        for dtype in state_dict_dtype[vae_prefix]:
            if state_dict_dtype[vae_prefix][dtype] > 240 and dtype in (torch.float16, torch.float32, torch.bfloat16):
                # VAE item counts: 248 for SD1/SDXL, 245 for FLUX
                return dtype

    return None

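In practice these helpers decide how a checkpoint should be handled before any weights are touched. A toy example of what the detection returns (keys and counts are placeholders, and float8 dtypes require pytorch >= 2.1):

import torch

sd = {
    "model.diffusion_model.a.weight": torch.zeros(1, dtype=torch.float8_e4m3fn),
    "model.diffusion_model.b.weight": torch.zeros(1, dtype=torch.float8_e4m3fn),
    "first_stage_model.conv.weight": torch.zeros(1, dtype=torch.bfloat16),
}

dtypes = get_state_dict_dtype(sd)
print(get_loadable_dtype("model.diffusion_model.", state_dict_dtype=dtypes))  # torch.float8_e4m3fn
print(get_vae_dtype(state_dict_dtype=dtypes))  # None here: fewer than ~240 VAE tensors in this toy dict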
|
||||
|
||||
def fix_position_ids(state_dict, force=False):
|
||||
# for SD1.5 or some SDXL with position_ids
|
||||
for prefix in ("cond_stage_models.", "conditioner.embedders.0."):
|
||||
position_id_key = f"{prefix}transformer.text_model.embeddings.position_ids"
|
||||
if position_id_key in state_dict:
|
||||
original = state_dict[position_id_key]
|
||||
if original.dtype == torch.int64:
|
||||
return
|
||||
|
||||
if force:
|
||||
# regenerate
|
||||
fixed = torch.tensor([list(range(77))], dtype=torch.int64, device=original.device)
|
||||
else:
|
||||
fixed = state_dict[position_id_key].to(torch.int64)
|
||||
print(f"Warning: Fixed position_ids dtype from {original.dtype} to {fixed.dtype}")
|
||||
|
||||
state_dict[position_id_key] = fixed
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
|
@ -414,6 +520,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
else:
|
||||
model.ztsnr = False
|
||||
|
||||
fix_position_ids(state_dict)
|
||||
|
||||
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
|
|
@ -427,6 +536,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
if hasattr(model, "before_load_weights"):
|
||||
model.before_load_weights(state_dict)
|
||||
|
||||
# get all dtypes of state_dict
|
||||
state_dict_dtype = get_state_dict_dtype(state_dict)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
|
||||
|
|
@ -452,7 +564,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
# check dtype of vae
|
||||
dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
|
||||
found_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
|
||||
unet_has_float = found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16)
|
||||
|
||||
if (found_unet_dtype is None or unet_has_float) and shared.cmd_opts.no_half:
|
||||
# unet type is not detected or unet has float dtypes
|
||||
model.float()
|
||||
model.alphas_cumprod_original = model.alphas_cumprod
|
||||
devices.dtype_unet = torch.float32
|
||||
|
|
@ -462,8 +580,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
|
||||
# preserve bfloat16 if it is supported
|
||||
model.first_stage_model = None
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
elif shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
|
|
@ -471,15 +592,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
|
||||
alphas_cumprod = model.alphas_cumprod
|
||||
model.alphas_cumprod = None
|
||||
model.half()
|
||||
|
||||
|
||||
if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
|
||||
model.half()
|
||||
elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
|
||||
pass
|
||||
else:
|
||||
print("Fail to get a vaild UNet dtype. ignore...")
|
||||
|
||||
model.alphas_cumprod = alphas_cumprod
|
||||
model.alphas_cumprod_original = alphas_cumprod
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
if found_unet_dtype in (torch.float16, torch.float32):
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
else:
|
||||
print(f"load Unet {found_unet_dtype} as is ...")
|
||||
devices.dtype_unet = found_unet_dtype if found_unet_dtype else torch.float16
|
||||
timer.record("load UNet")
|
||||
|
||||
apply_alpha_schedule_override(model)
|
||||
|
||||
|
|
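The branch above boils down to: call .half() only when the checkpoint actually stores float weights, and leave fp8 checkpoints untouched. A condensed, illustrative restatement of that decision (not the actual webui code path):

import torch

def should_call_half(found_unet_dtype):
    # illustrative summary of the branch above: .half() only makes sense
    # when the checkpoint stores float weights; fp8 weights stay as-is
    if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16):
        return True
    if found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return False
    return False  # unknown dtype: skip conversion, as the loader does

print(should_call_half(torch.bfloat16))       # True
print(should_call_half(torch.float8_e4m3fn))  # False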
@ -489,10 +623,18 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
if hasattr(module, 'fp16_bias'):
|
||||
del module.fp16_bias
|
||||
|
||||
if check_fp8(model):
|
||||
if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model):
|
||||
devices.fp8 = True
|
||||
|
||||
# do not convert vae, text_encoders.clip_l, clip_g, t5xxl
|
||||
first_stage = model.first_stage_model
|
||||
model.first_stage_model = None
|
||||
vae = getattr(model, 'vae', None)
|
||||
if vae is not None:
|
||||
model.vae = None
|
||||
text_encoders = getattr(model, 'text_encoders', None)
|
||||
if text_encoders is not None:
|
||||
model.text_encoders = None
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
|
||||
if shared.opts.cache_fp16_weight:
|
||||
|
|
@ -500,6 +642,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
if module.bias is not None:
|
||||
module.fp16_bias = module.bias.data.clone().cpu().half()
|
||||
module.to(torch.float8_e4m3fn)
|
||||
if text_encoders is not None:
|
||||
model.text_encoders = text_encoders
|
||||
if vae is not None:
|
||||
model.vae = vae
|
||||
model.first_stage_model = first_stage
|
||||
timer.record("apply fp8")
|
||||
else:
|
||||
|
|
@ -507,8 +653,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
timer.record("apply dtype to VAE")
|
||||
# check supported vae dtype
|
||||
dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
|
||||
if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
|
||||
devices.dtype_vae = torch.bfloat16
|
||||
print(f"VAE dtype {dtype_vae} detected. load as is.")
|
||||
else:
|
||||
# use default devices.dtype_vae
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
print(f"Use VAE dtype {devices.dtype_vae}")
|
||||
timer.record("apply dtype to VAE")
|
||||
|
||||
# clean up cache if limit is reached
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
|
|
@ -661,6 +815,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe
|
|||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
||||
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||
clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight'
|
||||
clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight'
|
||||
t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight'
|
||||
|
||||
|
||||
class SdModelData:
|
||||
|
|
@ -793,7 +950,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
|
|||
|
||||
if not checkpoint_config:
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight ] if x in state_dict)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
|
|
@ -804,6 +961,18 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
|
|||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
|
||||
# get all dtypes of state_dict
|
||||
state_dict_dtype = get_state_dict_dtype(state_dict)
|
||||
|
||||
# check loadable unet dtype before loading
|
||||
loadable_unet_dtype = get_loadable_dtype("model.diffusion_model.", state_dict_dtype=state_dict_dtype)
|
||||
|
||||
# check dtype of vae
|
||||
dtype_vae = get_vae_dtype(state_dict_dtype=state_dict_dtype)
|
||||
if dtype_vae == torch.bfloat16 and dtype_vae in devices.supported_vae_dtypes:
|
||||
devices.dtype_vae = torch.bfloat16
|
||||
print(f"VAE dtype {dtype_vae} detected.")
|
||||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||
|
|
@ -828,8 +997,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
|
|||
else:
|
||||
weight_dtype_conversion = {
|
||||
'first_stage_model': None,
|
||||
'text_encoders': None,
|
||||
'vae': None,
|
||||
'alphas_cumprod': None,
|
||||
'': torch.float16,
|
||||
'': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None,
|
||||
}
|
||||
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||
|
|
@ -856,7 +1027,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
|
|||
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
with devices.autocast(), torch.no_grad():
|
||||
with devices.autocast(target_dtype=devices.dtype_inference), torch.no_grad():
|
||||
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
||||
|
||||
timer.record("calculate empty prompt")
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
|||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
|
|
@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename):
|
|||
if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||
return config_sd3
|
||||
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
return config_flux1
|
||||
|
||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sdxl_inpainting
|
||||
|
|
|
|||
|
|
@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion):
|
|||
is_sd3: bool
|
||||
"""True if the model's architecture is SD 3"""
|
||||
|
||||
is_flux1: bool
|
||||
"""True if the model's architecture is FLUX 1"""
|
||||
|
||||
latent_channels: int
|
||||
"""number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
|||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
devices_args = dict(device=devices.device, dtype=devices.dtype)
|
||||
devices_args = dict(device=devices.device, dtype=devices.dtype_inference)
|
||||
|
||||
sdxl_conds = {
|
||||
"txt": batch,
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None):
|
|||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = 255. * np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
return Image.fromarray(x_sample)
|
||||
|
|
|
|||
|
|
@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
||||
loaded = False
|
||||
if vae_file:
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||
store_base_vae(model)
|
||||
|
||||
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
loaded = _load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
if loaded and cache_enabled:
|
||||
# cache newly loaded vae
|
||||
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
if loaded and cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
if loaded and vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
loaded = True
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
if loaded:
|
||||
loaded_vae_file = vae_file
|
||||
model.base_vae = base_vae
|
||||
model.loaded_vae_file = loaded_vae_file
|
||||
return loaded
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
|
||||
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
|
||||
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
|
||||
print("Failed to load VAE. Size mismatched!")
|
||||
return False
|
||||
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
return True
|
||||
|
||||
|
||||
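The shape check in _load_vae_dict is what turns a silently wrong VAE swap into an explicit failure: SD1.5/SDXL VAEs and SD3/FLUX VAEs differ in their encoder.conv_out.weight shape, so loading one into the other is rejected. A minimal illustration of the same compatibility test (hypothetical helper and shapes):

import torch

def vae_compatible(model_sd, vae_sd, key="encoder.conv_out.weight"):
    # same idea as the check above: refuse to load a VAE whose conv_out shape differs
    return model_sd[key].shape == vae_sd[key].shape

sdxl_like = {"encoder.conv_out.weight": torch.zeros(8, 512, 3, 3)}
flux_like = {"encoder.conv_out.weight": torch.zeros(32, 512, 3, 3)}
print(vae_compatible(sdxl_like, sdxl_like))  # True
print(vae_compatible(sdxl_like, flux_like))  # False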
def clear_loaded_vae():
|
||||
|
|
@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file, vae_source)
|
||||
loaded = load_vae(sd_model, vae_file, vae_source)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
|
|
@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
print("VAE weights loaded.")
|
||||
if loaded:
|
||||
print("VAE weights loaded.")
|
||||
return sd_model
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@ def model():
|
|||
model_name = "vaeapprox-sd3.pt"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "vaeapprox-sdxl.pt"
|
||||
elif shared.sd_model.is_flux1:
|
||||
model_name = "vaeapprox-sd3.pt"
|
||||
else:
|
||||
model_name = "model.pt"
|
||||
|
||||
|
|
@ -81,6 +83,18 @@ def cheap_approximation(sample):
|
|||
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
|
||||
]
|
||||
elif shared.sd_model.is_flux1:
|
||||
coeffs = [
|
||||
# from comfy
|
||||
[-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850],
|
||||
[ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549],
|
||||
[ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123],
|
||||
[ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306],
|
||||
[-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545],
|
||||
[-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088],
|
||||
[ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897],
|
||||
[-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680],
|
||||
]
|
||||
elif shared.sd_model.is_sdxl:
|
||||
coeffs = [
|
||||
[ 0.3448, 0.4168, 0.4395],
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
|
||||
latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4
|
||||
|
||||
self.decoder = decoder(latent_channels)
|
||||
self.decoder.load_state_dict(
|
||||
|
|
@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
|
||||
latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4
|
||||
|
||||
self.encoder = encoder(latent_channels)
|
||||
self.encoder.load_state_dict(
|
||||
|
|
@ -97,6 +97,8 @@ def download_model(model_path, model_url):
|
|||
def decoder_model():
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_decoder.pth"
|
||||
elif shared.sd_model.is_flux1:
|
||||
model_name = "taef1_decoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_decoder.pth"
|
||||
else:
|
||||
|
|
@ -122,6 +124,8 @@ def decoder_model():
|
|||
def encoder_model():
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_encoder.pth"
|
||||
elif shared.sd_model.is_flux1:
|
||||
model_name = "taef1_encoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_encoder.pth"
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def initialize():
|
|||
|
||||
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
|
||||
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype_inference
|
||||
|
||||
if cmd_opts.precision == "half":
|
||||
msg = "--no-half and --no-half-vae conflict with --precision half"
|
||||
|
|
|
|||
|
|
@ -196,6 +196,9 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"),
|
|||
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
|
||||
"sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
|
||||
}))
|
||||
options_templates.update(options_section(('flux', "Stable Diffusion FLUX", "sd"), {
|
||||
"flux_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('vae', "VAE", "sd"), {
|
||||
"sd_vae_explanation": OptionHTML("""
|
||||
|
|
@ -243,6 +246,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
|||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"),
|
||||
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
||||
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
|
||||
"lora_without_backup_weight": OptionInfo(False, "LoRA without backup weights").info("LoRA without backup weights to save RAM."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||
|
|
|
|||
21
webui.py
|
|
@ -6,6 +6,8 @@ import time
|
|||
from modules import timer
|
||||
from modules import initialize_util
|
||||
from modules import initialize
|
||||
from modules import manager
|
||||
from threading import Thread
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
|
|
@ -14,6 +16,8 @@ initialize.imports()
|
|||
|
||||
initialize.check_versions()
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
|
||||
def create_api(app):
|
||||
from modules.api.api import Api
|
||||
|
|
@ -23,12 +27,10 @@ def create_api(app):
|
|||
return api
|
||||
|
||||
|
||||
def api_only():
|
||||
def _api_only():
|
||||
from fastapi import FastAPI
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
app = FastAPI()
|
||||
initialize_util.setup_middleware(app)
|
||||
api = create_api(app)
|
||||
|
|
@ -83,11 +85,10 @@ For more information see: https://github.com/AUTOMATIC1111/stable-diffusion-webu
|
|||
{"!"*25} Warning {"!"*25}''')
|
||||
|
||||
|
||||
def webui():
|
||||
def _webui():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
launch_api = cmd_opts.api
|
||||
initialize.initialize()
|
||||
|
||||
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
|
||||
|
||||
|
|
@ -177,6 +178,7 @@ def webui():
|
|||
print("Stopping server...")
|
||||
# If we catch a keyboard interrupt, we want to stop the server and exit.
|
||||
shared.demo.close()
|
||||
manager.task.stop()
|
||||
break
|
||||
|
||||
# disable auto launch webui in browser for subsequent UI Reload
|
||||
|
|
@ -193,6 +195,13 @@ def webui():
|
|||
initialize.initialize_rest(reload_script_modules=True)
|
||||
|
||||
|
||||
def api_only():
|
||||
Thread(target=_api_only, daemon=True).start()
|
||||
|
||||
|
||||
def webui():
|
||||
Thread(target=_webui, daemon=True).start()
|
||||
|
||||
if __name__ == "__main__":
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
|
|
@ -200,3 +209,5 @@ if __name__ == "__main__":
|
|||
api_only()
|
||||
else:
|
||||
webui()
|
||||
|
||||
manager.task.main_loop()
|
||||
|
|
|
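With this restructuring, the Gradio/FastAPI servers are started on daemon threads while the main thread stays in manager.task.main_loop(); the manager module itself is introduced elsewhere in this PR, so the sketch below only illustrates the general pattern (a daemon worker plus a blocking main loop that can be stopped), not the actual manager API.

import queue
import threading
import time

task_queue = queue.Queue()
stop_event = threading.Event()

def server():
    # stands in for _webui()/_api_only(): runs until asked to stop
    while not stop_event.is_set():
        time.sleep(0.1)

threading.Thread(target=server, daemon=True).start()

def main_loop():
    # stands in for manager.task.main_loop(): keeps the main thread alive
    # and processes queued work until stop() is called
    while not stop_event.is_set():
        try:
            job = task_queue.get(timeout=0.5)
        except queue.Empty:
            continue
        job()

task_queue.put(lambda: print("hello from the main loop"))
task_queue.put(stop_event.set)  # equivalent of manager.task.stop()
main_loop()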
|||