Commit 6f6e673

patch 'state_dict' to retrieve the original state_dict

wkpark committed Oct 31, 2023
1 parent 464fbcd commit 6f6e673

Showing 2 changed files with 45 additions and 0 deletions.
10 changes: 10 additions & 0 deletions extensions-builtin/Lora/lora_patches.py
@@ -8,24 +8,34 @@ class LoraPatches:
     def __init__(self):
         self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
         self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+        self.Linear_state_dict = patches.patch(__name__, torch.nn.Linear, 'state_dict', networks.network_Linear_state_dict)
         self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
         self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+        self.Conv2d_state_dict = patches.patch(__name__, torch.nn.Conv2d, 'state_dict', networks.network_Conv2d_state_dict)
         self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
         self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+        self.GroupNorm_state_dict = patches.patch(__name__, torch.nn.GroupNorm, 'state_dict', networks.network_GroupNorm_state_dict)
         self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
         self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+        self.LayerNorm_state_dict = patches.patch(__name__, torch.nn.LayerNorm, 'state_dict', networks.network_LayerNorm_state_dict)
         self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
         self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+        self.MultiheadAttention_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, 'state_dict', networks.network_MultiheadAttention_state_dict)
 
     def undo(self):
         self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
         self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+        self.Linear_state_dict = patches.undo(__name__, torch.nn.Linear, 'state_dict')
         self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
         self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+        self.Conv2d_state_dict = patches.undo(__name__, torch.nn.Conv2d, 'state_dict')
         self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
         self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+        self.GroupNorm_state_dict = patches.undo(__name__, torch.nn.GroupNorm, 'state_dict')
         self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
         self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+        self.LayerNorm_state_dict = patches.undo(__name__, torch.nn.LayerNorm, 'state_dict')
         self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
         self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+        self.MultiheadAttention_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, 'state_dict')
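
These calls rely on the webui's `modules/patches.py` helper: `patches.patch` swaps an attribute on a class for a replacement and returns the original, and `patches.undo` puts the original back. Below is a minimal sketch of such a helper; the dictionary-based bookkeeping is an assumption for illustration, not the webui's actual implementation.

```python
# Minimal monkey-patching helper in the spirit of modules/patches.py.
# The (key, owner, field) bookkeeping here is an illustrative assumption.

_originals = {}  # (key, owner, field) -> original attribute


def patch(key, obj, field, replacement):
    """Replace obj.field with `replacement` and return the original attribute."""
    patch_key = (key, obj, field)
    if patch_key in _originals:
        raise RuntimeError(f"{field!r} on {obj} is already patched by {key!r}")

    original = getattr(obj, field)
    _originals[patch_key] = original
    setattr(obj, field, replacement)
    return original


def undo(key, obj, field):
    """Restore the original attribute recorded by patch()."""
    setattr(obj, field, _originals.pop((key, obj, field)))
    return None  # callers clear their stored reference, as LoraPatches.undo does
```

Patching `state_dict` alongside `forward` and `_load_from_state_dict` means any code that serializes one of these module types goes through the LoRA-aware wrapper added in networks.py below.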

35 changes: 35 additions & 0 deletions extensions-builtin/Lora/networks.py
@@ -535,6 +535,41 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
     return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
 
 
+def network_Linear_state_dict(self, *args, **kwargs):
+    with torch.no_grad():
+        network_restore_weights_from_backup(self)
+
+    return originals.Linear_state_dict(self, *args, **kwargs)
+
+
+def network_Conv2d_state_dict(self, *args, **kwargs):
+    with torch.no_grad():
+        network_restore_weights_from_backup(self)
+
+    return originals.Conv2d_state_dict(self, *args, **kwargs)
+
+
+def network_GroupNorm_state_dict(self, *args, **kwargs):
+    with torch.no_grad():
+        network_restore_weights_from_backup(self)
+
+    return originals.GroupNorm_state_dict(self, *args, **kwargs)
+
+
+def network_LayerNorm_state_dict(self, *args, **kwargs):
+    with torch.no_grad():
+        network_restore_weights_from_backup(self)
+
+    return originals.LayerNorm_state_dict(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_state_dict(self, *args, **kwargs):
+    with torch.no_grad():
+        network_restore_weights_from_backup(self)
+
+    return originals.MultiheadAttention_state_dict(self, *args, **kwargs)
+
+
 def list_available_networks():
     available_networks.clear()
     available_network_aliases.clear()
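Each wrapper follows the same pattern: restore the module's original weights from the backup taken when a network was applied, then delegate to the unpatched `state_dict`, so a checkpoint saved while a LoRA is active contains the original weights rather than the merged ones. Here is a self-contained sketch of that idea; the `network_weights_backup` attribute and the toy in-place merge are illustrative stand-ins for the extension's actual bookkeeping.

```python
import torch


def network_restore_weights_from_backup(module):
    # Copy the backed-up (pre-merge) weights over the merged ones in place.
    backup = getattr(module, "network_weights_backup", None)
    if backup is not None:
        module.weight.copy_(backup)


linear = torch.nn.Linear(4, 4)
linear.network_weights_backup = linear.weight.detach().clone()  # backup before merging

with torch.no_grad():
    linear.weight += 0.1  # stand-in for merging a LoRA delta into the weight

with torch.no_grad():
    network_restore_weights_from_backup(linear)

# state_dict() now reports the original weights, not the merged ones.
assert torch.equal(linear.state_dict()["weight"], linear.network_weights_backup)
```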