
[Performance 1/6] use_checkpoint = False #15803

Merged
merged 3 commits into dev on Jun 8, 2024

Conversation

huchenlei (Contributor) commented May 15, 2024

Description

According to lllyasviel/stable-diffusion-webui-forge#716 (comment), the calls to parameters in the checkpoint function are a significant overhead in A1111. However, the checkpoint function is mainly used for training, so disabling it does not affect inference at all.

This PR disables checkpointing in A1111 in exchange for a performance improvement. It saves about 100 ms/it on my local setup (4090); the duration before the patch is ~580 ms/it.
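For context, here is a minimal sketch (not the webui's actual code; Block and its dimensions are illustrative) of why activation checkpointing is pure overhead at inference time: it exists to trade compute for memory during backpropagation, and sampling has no backward pass.

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """Toy block used only to illustrate the checkpointing tradeoff."""
    def __init__(self, dim=64, use_checkpoint=True):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU())
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        if self.use_checkpoint and self.training:
            # Saves activation memory by re-running self.net during backward,
            # at the cost of extra forward compute and bookkeeping.
            return checkpoint(self.net, x, use_reentrant=False)
        # Plain call: the fast path this PR wants for inference.
        return self.net(x)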

Screenshots/videos: (screenshot omitted)


huchenlei requested a review from AUTOMATIC1111 as a code owner May 15, 2024 19:29
huchenlei changed the title from use_checkpoint = False to [Performance 1/6] use_checkpoint = False May 15, 2024
huchenlei changed the base branch from master to dev May 15, 2024 19:30
def BasicTransformerBlock_forward(self, x, context=None):
-    return checkpoint(self._forward, x, context)
+    return checkpoint(self._forward, x, context, flag=False)
Contributor

The checkpoint here is torch.utils.checkpoint.checkpoint, and it does not take a flag argument. I think you confused this with ldm.modules.diffusionmodules.util.checkpoint. sd_hijack_checkpoint.py already removes the checkpointing in ldm, but we might need to do it on sgm as well.
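To make the distinction concrete, here is a rough paraphrase (not the actual ldm source) of the flag-style helper being confused with PyTorch's function; torch.utils.checkpoint.checkpoint has no flag parameter, so passing flag=False does not turn checkpointing off.

import torch

def ldm_style_checkpoint(func, inputs, params, flag):
    # Paraphrase of ldm.modules.diffusionmodules.util.checkpoint: the last
    # positional argument is a boolean that enables/disables checkpointing.
    if flag:
        # ldm uses its own CheckpointFunction; PyTorch's implementation is
        # used here purely for illustration.
        return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False)
    return func(*inputs)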

Contributor

As far as I understand, what you are looking for is actually

ldm.modules.attention.BasicTransformerBlock.forward = ldm.modules.attention.BasicTransformerBlock._forward
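A sketch of how that suggestion might be extended to also cover the sgm package mentioned above (this assumes sgm's BasicTransformerBlock mirrors ldm's layout, which is an assumption rather than something verified in this thread):

import ldm.modules.attention
import sgm.modules.attention  # SDXL code path; assumed to mirror ldm's layout

for attention in (ldm.modules.attention, sgm.modules.attention):
    # Bypass the checkpoint wrapper by pointing forward at the plain _forward.
    attention.BasicTransformerBlock.forward = attention.BasicTransformerBlock._forward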

Contributor

A closer look indicates that the checkpoint here is only called when training occurs (textual_inversion & hypernetwork), so disabling checkpoint here may be undesirable.
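If that is the case, one option would be to toggle the patch around training rather than remove checkpointing globally; a rough sketch with illustrative helper names (not the webui's actual API):

import ldm.modules.attention

# Keep a reference to the original (checkpointed) forward so it can be restored.
_checkpointed_forward = ldm.modules.attention.BasicTransformerBlock.forward

def enable_checkpointing():
    # Training paths (textual inversion, hypernetworks) that want the memory savings.
    ldm.modules.attention.BasicTransformerBlock.forward = _checkpointed_forward

def disable_checkpointing():
    # Inference: skip the checkpoint bookkeeping entirely.
    ldm.modules.attention.BasicTransformerBlock.forward = (
        ldm.modules.attention.BasicTransformerBlock._forward
    )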

AUTOMATIC1111 merged commit ad229fa into AUTOMATIC1111:dev Jun 8, 2024
3 checks passed
lawchingman mentioned this pull request Oct 5, 2024
catboxanon added a commit that referenced this pull request Oct 29, 2024
catboxanon mentioned this pull request Oct 29, 2024