diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 5f00d36e7..12e1c24e1 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -526,7 +526,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)
 
-            if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None:
+            if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None and self.weight.shape[0] // 3 == module_q.up_model.weight.shape[0]:
                 try:
                     with torch.no_grad():
                         # Send "real" orig_weight into MHA's lora module
@@ -545,14 +545,17 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                 continue
 
-            if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp:
+            if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v:
                 try:
                     with torch.no_grad():
                         qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
                         updown_q, _ = module_q.calc_updown(qw)
                         updown_k, _ = module_k.calc_updown(kw)
                         updown_v, _ = module_v.calc_updown(vw)
-                        updown_mlp, _ = module_v.calc_updown(mlp)
+                        if module_mlp is not None:
+                            updown_mlp, _ = module_mlp.calc_updown(mlp)
+                        else:
+                            updown_mlp = torch.zeros(3072 * 4, 3072, dtype=updown_q.dtype, device=updown_q.device)
                         del qw, kw, vw, mlp
                         updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
                         self.weight += updown_qkv_mlp