diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b8a9ae65..dcf816b3a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: - module.fp16_weight = module.weight.clone().half() + module.fp16_weight = module.weight.data.clone().cpu().half() if module.bias is not None: - module.fp16_bias = module.bias.clone().half() + module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8")