Merge pull request #15260 from v0xie/fix-OFT-MhA-AttributeError
Fix AttributeError in OFT when trying to get MultiheadAttention weight
AUTOMATIC1111 authored Mar 16, 2024
2 parents 0cc3647 + 07805cb commit 3cb698a
Showing 1 changed file with 7 additions and 7 deletions.
extensions-builtin/Lora/network_oft.py: 14 changes (7 additions & 7 deletions)
@@ -36,13 +36,6 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         # self.alpha is unused
         self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)
 
-        # LyCORIS BOFT
-        if self.oft_blocks.dim() == 4:
-            self.is_boft = True
-        self.rescale = weights.w.get('rescale', None)
-        if self.rescale is not None:
-            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
-
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported
@@ -54,6 +47,13 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim
 
+        # LyCORIS BOFT
+        if self.oft_blocks.dim() == 4:
+            self.is_boft = True
+        self.rescale = weights.w.get('rescale', None)
+        if self.rescale is not None and not is_other_linear:
+            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
+
         self.num_blocks = self.dim
         self.block_size = self.out_dim // self.dim
         self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
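For context on the fix: torch.nn.MultiheadAttention does not expose a plain .weight attribute (its parameters live in in_proj_weight / q_proj_weight and out_proj.weight), so the previously unguarded self.org_module[0].weight.dim() raised AttributeError for these modules. A minimal standalone sketch of the failure mode and of the reshape arithmetic the new `not is_other_linear` guard skips (the toy modules below are illustrative, not code from the repository):

    import torch

    linear = torch.nn.Linear(8, 8)
    conv = torch.nn.Conv2d(3, 16, kernel_size=3)
    mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)

    # Supported layers expose .weight, so the rescale reshape target
    # (-1 followed by weight.dim() - 1 singleton dims) is well defined:
    print(linear.weight.dim())  # 2 -> rescale reshaped to (-1, 1)
    print(conv.weight.dim())    # 4 -> rescale reshaped to (-1, 1, 1, 1)

    # MultiheadAttention stores its weights as in_proj_weight and
    # out_proj.weight instead, so the unguarded access used to fail:
    try:
        mha.weight.dim()
    except AttributeError as err:
        print(err)  # 'MultiheadAttention' object has no attribute 'weight'

With the guard in place, self.rescale is simply left in its stored shape for MultiheadAttention modules, which this file already marks as unsupported.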
