From 0ab4d7992c4b3c65de7200a2adca0afa85907cc1 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Wed, 2 Oct 2024 20:02:30 +0900
Subject: [PATCH] reduce backup_weight size for float8 freeze model

---
 extensions-builtin/Lora/networks.py           | 15 +++--
 .../Lora/scripts/lora_script.py               | 61 ++++++++++++++++++-
 2 files changed, 69 insertions(+), 7 deletions(-)

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 12e1c24e1..76cef0a55 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -377,13 +377,13 @@ def allowed_layer_without_weight(layer):
     return False
 
 
-def store_weights_backup(weight):
+def store_weights_backup(weight, dtype):
     if weight is None:
         return None
 
     if shared.opts.lora_without_backup_weight:
         return True
-    return weight.to(devices.cpu, copy=True)
+    return weight.to(devices.cpu, dtype=dtype, copy=True)
 
 
 def restore_weights_backup(obj, field, weight):
@@ -437,18 +437,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
 
         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
+            weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
         else:
-            weights_backup = store_weights_backup(self.weight)
+            weights_backup = store_weights_backup(self.weight, self.org_dtype)
 
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = store_weights_backup(self.out_proj.bias)
+            bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = store_weights_backup(self.bias)
+            bias_backup = store_weights_backup(self.bias, self.org_dtype)
         else:
             bias_backup = None
 
@@ -487,6 +487,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                         self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                     else:
                         self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
+            del weight, bias, updown, ex_bias
         except RuntimeError as e:
             logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
             extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -538,6 +539,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
             self.weight += updown_qkv
             del updown_qkv
+            del updown_q, updown_k, updown_v
 
         except RuntimeError as e:
             logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
@@ -560,6 +562,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
             self.weight += updown_qkv_mlp
             del updown_qkv_mlp
+            del updown_q, updown_k, updown_v, updown_mlp
 
         except RuntimeError as e:
             logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index d3ea369ae..8ee93efef 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,4 +1,5 @@
 import re
+import torch
 
 import gradio as gr
 from fastapi import FastAPI
@@ -9,7 +10,7 @@ import lora # noqa:F401
 import lora_patches
 import extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices
 
 
 def unload():
@@ -97,6 +98,64 @@ def infotext_pasted(infotext, d):
         d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
 
 
+class ScriptLora(scripts.Script):
+    name = "Lora"
+
+    def title(self):
+        return self.name
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def after_extra_networks_activate(self, p, *args, **kwargs):
+        # check modules and setup org_dtype
+        modules = []
+        if shared.sd_model.is_sdxl:
+            for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
+                if not hasattr(embedder, 'wrapped'):
+                    continue
+
+                for _name, module in embedder.wrapped.named_modules():
+                    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
+                        if hasattr(module, 'weight'):
+                            modules.append(module)
+                        elif isinstance(module, torch.nn.MultiheadAttention):
+                            modules.append(module)
+
+        else:
+            cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+            for _name, module in cond_stage_model.named_modules():
+                if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
+                    if hasattr(module, 'weight'):
+                        modules.append(module)
+                    elif isinstance(module, torch.nn.MultiheadAttention):
+                        modules.append(module)
+
+        for _name, module in shared.sd_model.model.named_modules():
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
+                if hasattr(module, 'weight'):
+                    modules.append(module)
+                elif isinstance(module, torch.nn.MultiheadAttention):
+                    modules.append(module)
+
+        print("Total lora modules after_extra_networks_activate() =", len(modules))
+
+        target_dtype = devices.dtype_inference
+        for module in modules:
+            if isinstance(module, torch.nn.MultiheadAttention):
+                org_dtype = torch.float32
+            else:
+                org_dtype = None
+            for _name, param in module.named_parameters():
+                if param.dtype != target_dtype:
+                    org_dtype = param.dtype
+                    break
+
+            # set org_dtype
+            module.org_dtype = org_dtype
+
+
 script_callbacks.on_infotext_pasted(infotext_pasted)
 
 shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
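
Not part of the patch: a minimal standalone sketch of the idea behind passing the original dtype into store_weights_backup(). For a model frozen to float8 but run in fp16/bf16, keeping the CPU backup in the weight's original dtype instead of the upcast inference dtype halves the backup size. The helper name make_backup, the CPU constant, and the example tensor size are illustrative only, and torch.float8_e4m3fn assumes a PyTorch build (>= 2.1) with float8 dtypes.

import torch

CPU = torch.device("cpu")

def make_backup(weight, dtype):
    # same shape as store_weights_backup(weight, dtype): one CPU copy in the requested backup dtype
    return weight.to(CPU, dtype=dtype, copy=True)

# a weight frozen to float8 (hypothetical example tensor)
org_dtype = torch.float8_e4m3fn
weight_fp8 = torch.randn(4096, 4096, dtype=torch.float16).to(org_dtype)

# at inference time the weight is used upcast to the inference dtype
weight_runtime = weight_fp8.to(torch.float16)

backup_old = make_backup(weight_runtime, torch.float16)  # backup in the inference dtype
backup_new = make_backup(weight_runtime, org_dtype)      # backup in the original float8 dtype

print(backup_old.nelement() * backup_old.element_size())  # 33554432 bytes (~32 MiB)
print(backup_new.nelement() * backup_new.element_size())  # 16777216 bytes (~16 MiB)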