From 1e73a287075e6d3ef9059c8748b58972b6f5a367 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 17 Sep 2024 10:07:58 +0900 Subject: [PATCH] fix for float8_e5m2 freeze model --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3e0b577bb..6dffdc036 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -585,7 +585,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16): model.half() - elif found_unet_dtype in (torch.float8_e4m3fn,): + elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): pass else: print("Fail to get a vaild UNet dtype. ignore...") @@ -612,7 +612,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(module, 'fp16_bias'): del module.fp16_bias - if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model): + if found_unet_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2) and check_fp8(model): devices.fp8 = True # do not convert vae, text_encoders.clip_l, clip_g, t5xxl