diff --git a/modules/devices.py b/modules/devices.py
index 556e72d2e..866b6ab16 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,5 +1,6 @@
 import sys
 import contextlib
+from copy import deepcopy
 from functools import lru_cache
 
 import torch
@@ -169,7 +170,7 @@ patch_module_list = [
 ]
 
 
-def manual_cast_forward(target_dtype, target_device=None):
+def manual_cast_forward(target_dtype, target_device=None, copy=False):
     params = dict()
     if supports_non_blocking():
         params['non_blocking'] = True
@@ -193,13 +194,22 @@ def manual_cast_forward(target_dtype, target_device=None):
                 org_dtype = param.dtype
                 break
 
-        if org_dtype != target_dtype:
-            self.to(**params)
-        result = self.org_forward(*args, **kwargs)
+        if copy:
+            copied = deepcopy(self)
+            if org_dtype != target_dtype:
+                copied.to(**params)
 
-        if org_dtype != target_dtype:
-            params['dtype'] = org_dtype
-            self.to(**params)
+            result = copied.org_forward(*args, **kwargs)
+            del copied
+        else:
+            if org_dtype != target_dtype:
+                self.to(**params)
+
+            result = self.org_forward(*args, **kwargs)
+
+            if org_dtype != target_dtype:
+                params['dtype'] = org_dtype
+                self.to(**params)
 
         if target_dtype != dtype_inference:
             params['dtype'] = dtype_inference
@@ -220,15 +230,17 @@ def manual_cast_forward(target_dtype, target_device=None):
 def manual_cast(target_dtype, target_device=None):
     applied = False
+    copy = shared.opts.lora_without_backup_weight
+
     for module_type in patch_module_list:
         if hasattr(module_type, "org_forward"):
             continue
 
         applied = True
         org_forward = module_type.forward
         if module_type == torch.nn.MultiheadAttention:
-            module_type.forward = manual_cast_forward(torch.float32, target_device)
+            module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
         else:
-            module_type.forward = manual_cast_forward(target_dtype, target_device)
+            module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
         module_type.org_forward = org_forward
     try:
         yield None
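For context, the new copy branch trades the usual cast-then-restore dance for running the forward pass on a deepcopy, so the module's own weights are never touched (relevant when a LoRA has been applied in place without a weight backup). Below is a minimal standalone sketch of the two strategies, not code from this patch; the helper name cast_and_forward is hypothetical.

# Illustrative sketch only: the two casting strategies the `copy` flag selects between.
from copy import deepcopy

import torch


def cast_and_forward(module, x, target_dtype, copy=False):
    org_dtype = next(module.parameters()).dtype

    if copy:
        # Copy path: cast a throwaway deepcopy, so the module's stored weights
        # are never modified (no backup needed to restore them afterwards).
        copied = deepcopy(module)
        if org_dtype != target_dtype:
            copied.to(dtype=target_dtype)
        result = copied(x.to(target_dtype))
        del copied
    else:
        # In-place path: cast the module itself, run it, then cast it back.
        if org_dtype != target_dtype:
            module.to(dtype=target_dtype)
        result = module(x.to(target_dtype))
        if org_dtype != target_dtype:
            module.to(dtype=org_dtype)
    return result


if __name__ == "__main__":
    layer = torch.nn.Linear(4, 4)                     # float32 weights
    out = cast_and_forward(layer, torch.randn(1, 4), torch.bfloat16, copy=True)
    print(out.dtype, next(layer.parameters()).dtype)  # bfloat16 output, float32 weights preserved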