diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index eb17a31c7..14fa4e255 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -270,7 +270,8 @@ class FLUX1Inferencer(torch.nn.Module): with torch.no_grad(): self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype) self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) - self.first_stage_model.dtype = self.model.diffusion_model.dtype + self.first_stage_model.dtype = devices.dtype_vae + self.vae = self.first_stage_model # real vae self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index 41649970b..b4702151a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -956,6 +956,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_ else: weight_dtype_conversion = { 'first_stage_model': None, + 'text_encoders': None, + 'vae': None, 'alphas_cumprod': None, '': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None, }