diff --git a/modules/sd_models.py b/modules/sd_models.py index 7fc4093d894..b515a633a62 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -581,7 +581,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if found_unet_dtype in (torch.float16, torch.float32, torch.bfloat16): model.half() - elif found_unet_dtype in (torch.float8_e4m3fn,): + elif found_unet_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): pass else: print("Fail to get a vaild UNet dtype. ignore...") @@ -608,7 +608,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if hasattr(module, 'fp16_bias'): del module.fp16_bias - if found_unet_dtype not in (torch.float8_e4m3fn,) and check_fp8(model): + if found_unet_dtype not in (torch.float8_e4m3fn,torch.float8_e5m2) and check_fp8(model): devices.fp8 = True # do not convert vae, text_encoders.clip_l, clip_g, t5xxl