patch 'state_dict' to retrieve the original state_dict #13808

Closed · wants to merge 1 commit

Conversation

@wkpark (Contributor) commented Oct 31, 2023

Description

  • After partially updating/modifying the UNet at the block level, shared.sd_model.state_dict() works nicely, but after using any LoRA in the prompt, sd_model.state_dict() no longer works as expected (it returns the LoRA-patched weights instead of the original ones).

This fix adds more patches to LoraPatches, hooking state_dict() so that it retrieves the original weights/biases.

Without this fix, the original state_dict() can still be obtained with the following patch/hook:

import torch

from modules import patches, script_callbacks

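# Temporarily replaces state_dict on the torch.nn classes patched by the Lora extension,
# restoring each module's original weights from its backup before the state dict is read.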
class StateDictPatches:
    def __init__(self):
        import networks

        def network_Linear_state_dict(module, *args, **kwargs):
            with torch.no_grad():
                networks.network_restore_weights_from_backup(module)

            return self.Linear_state_dict(module, *args, **kwargs)


        def network_Conv2d_state_dict(module, *args, **kwargs):
            with torch.no_grad():
                networks.network_restore_weights_from_backup(module)

            return self.Conv2d_state_dict(module, *args, **kwargs)


        def network_GroupNorm_state_dict(module, *args, **kwargs):
            with torch.no_grad():
                networks.network_restore_weights_from_backup(module)

            return self.GroupNorm_state_dict(module, *args, **kwargs)


        def network_LayerNorm_state_dict(module, *args, **kwargs):
            with torch.no_grad():
                networks.network_restore_weights_from_backup(module)

            return self.LayerNorm_state_dict(module, *args, **kwargs)


        def network_MultiheadAttention_state_dict(module, *args, **kwargs):
            with torch.no_grad():
                networks.network_restore_weights_from_backup(module)

            return self.MultiheadAttention_state_dict(module, *args, **kwargs)


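        # install the state_dict hooks; patches.patch returns the original method,
        # which the wrappers above call after restoring the weights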
        self.Linear_state_dict = patches.patch(__name__, torch.nn.Linear, 'state_dict', network_Linear_state_dict)
        self.Conv2d_state_dict = patches.patch(__name__, torch.nn.Conv2d, 'state_dict', network_Conv2d_state_dict)
        self.GroupNorm_state_dict = patches.patch(__name__, torch.nn.GroupNorm, 'state_dict', network_GroupNorm_state_dict)
        self.LayerNorm_state_dict = patches.patch(__name__, torch.nn.LayerNorm, 'state_dict', network_LayerNorm_state_dict)
        self.MultiheadAttention_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, 'state_dict', network_MultiheadAttention_state_dict)


    def undo(self):
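        # remove the hooks, restoring the unpatched state_dict implementations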
        self.Linear_state_dict = patches.undo(__name__, torch.nn.Linear, 'state_dict')
        self.Conv2d_state_dict = patches.undo(__name__, torch.nn.Conv2d, 'state_dict')
        self.GroupNorm_state_dict = patches.undo(__name__, torch.nn.GroupNorm, 'state_dict')
        self.LayerNorm_state_dict = patches.undo(__name__, torch.nn.LayerNorm, 'state_dict')
        self.MultiheadAttention_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, 'state_dict')
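
For reference, a minimal usage sketch of the workaround class above (the patcher name is illustrative; it assumes the code runs inside the webui where modules.shared is importable):

from modules import shared

patcher = StateDictPatches()
try:
    # with the hooks installed, state_dict() returns the original (pre-LoRA) weights
    original_sd = shared.sd_model.state_dict()
finally:
    # remove the hooks so state_dict() behaves normally again
    patcher.undo()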

Checklist:

@wkpark requested a review from AUTOMATIC1111 as a code owner on October 31, 2023 08:10
@AUTOMATIC1111 (Owner) commented:

The problem is that you're expecting the state_dict of the original checkpoint but are getting the state_dict of the checkpoint with loras applied, is that correct? I kind of thought that getting weights with loras was the desired outcome.

@wkpark (Contributor, Author) commented Nov 4, 2023

The problem is that you're expecting the state_dict of the original checkpoint but are getting the state_dict of the checkpoint with loras applied, is that correct? I kind of thought that getting weights with loras was the desired outcome.

I don't think this is the intended result, but the current behavior can be useful for merging LoRAs into a checkpoint.

Anyway, even without this PR, the original state_dict can be obtained by partially hooking/patching state_dict in the same way this PR does.

@wkpark closed this Nov 8, 2023