Skip to content

Commit

Permalink
add sdp_on_bf16 to controlnet (#1631)
Browse files Browse the repository at this point in the history
* add sdp_on_bf16 to controlnet

* Update pipeline_controlnet.py

pass sdp_on_bf16 to controlnet_pipeline

* Update text_to_image_generation.py

* Update text_to_image_generation.py
  • Loading branch information
skaulintel authored Dec 20, 2024
1 parent 948476a commit 1e747dd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/stable-diffusion/training/train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def log_validation(
use_habana=True,
use_hpu_graphs=args.use_hpu_graphs,
gaudi_config=gaudi_config,
sdp_on_bf16=args.sdp_on_bf16,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
Expand Down Expand Up @@ -438,6 +439,12 @@ def parse_args(input_args=None):
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--sdp_on_bf16",
action="store_true",
default=False,
help="Allow pyTorch to use reduced precision in the SDPA math backend",
)
parser.add_argument(
"--bf16",
action="store_true",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ def __init__(
use_hpu_graphs: bool = False,
gaudi_config: Union[str, GaudiConfig] = None,
bf16_full_eval: bool = False,
sdp_on_bf16: bool = True,
):
GaudiDiffusionPipeline.__init__(
self,
use_habana,
use_hpu_graphs,
gaudi_config,
bf16_full_eval,
sdp_on_bf16,
)

StableDiffusionControlNetPipeline.__init__(
Expand Down

0 comments on commit 1e747dd

Please sign in to comment.