diff --git a/examples/stable-diffusion/training/train_controlnet.py b/examples/stable-diffusion/training/train_controlnet.py
index e676ae6ddf..5648952413 100755
--- a/examples/stable-diffusion/training/train_controlnet.py
+++ b/examples/stable-diffusion/training/train_controlnet.py
@@ -120,6 +120,7 @@ def log_validation(
         use_habana=True,
         use_hpu_graphs=args.use_hpu_graphs,
         gaudi_config=gaudi_config,
+        sdp_on_bf16=args.sdp_on_bf16,
     )
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
     pipeline = pipeline.to(accelerator.device)
@@ -438,6 +439,12 @@ def parse_args(input_args=None):
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
+    parser.add_argument(
+        "--sdp_on_bf16",
+        action="store_true",
+        default=False,
+        help="Allow pyTorch to use reduced precision in the SDPA math backend",
+    )
     parser.add_argument(
         "--bf16",
         action="store_true",
diff --git a/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py b/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py
index b4566a0241..2cf8c866ec 100644
--- a/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -100,6 +100,7 @@ def __init__(
         use_hpu_graphs: bool = False,
         gaudi_config: Union[str, GaudiConfig] = None,
         bf16_full_eval: bool = False,
+        sdp_on_bf16: bool = True,
     ):
         GaudiDiffusionPipeline.__init__(
             self,
@@ -107,6 +108,7 @@ def __init__(
             use_hpu_graphs,
             gaudi_config,
             bf16_full_eval,
+            sdp_on_bf16,
         )
         StableDiffusionControlNetPipeline.__init__(
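
A minimal sketch of how the new option surfaces at the API level. It assumes the pipeline class defined in pipeline_controlnet.py is GaudiStableDiffusionControlNetPipeline and uses placeholder model IDs and a placeholder gaudi_config name; the only part taken from this diff is the sdp_on_bf16 keyword (exposed to the training script as --sdp_on_bf16).

    # Sketch only: model IDs and the gaudi_config name are illustrative, not part of this PR.
    import torch
    from diffusers import ControlNetModel, UniPCMultistepScheduler
    from optimum.habana.diffusers import GaudiStableDiffusionControlNetPipeline

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=torch.bfloat16
    )
    pipeline = GaudiStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        use_habana=True,
        use_hpu_graphs=True,
        gaudi_config="Habana/stable-diffusion",
        # New in this change: allow reduced precision in the SDPA math backend.
        sdp_on_bf16=True,
    )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)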