Commit

Fix bugs for torch.nn.MultiheadAttention
KohakuBlueleaf committed Mar 9, 2024
1 parent 12bcacf commit 851c3d5
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 7 additions & 1 deletion extensions-builtin/Lora/network.py
@@ -117,6 +117,12 @@ def __init__(self, net: Network, weights: NetworkWeights):
 
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
+        elif isinstance(self.sd_module, nn.MultiheadAttention):
+            # For now, only self-attn use Pytorch's MHA
+            # So assume all qkvo proj have same shape
+            self.shape = self.sd_module.out_proj.weight.shape
+        else:
+            self.shape = None
 
         self.ops = None
         self.extra_kwargs = {}
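The new branch relies on the layout of a self-attention nn.MultiheadAttention: the q/k/v projections are packed into a single in_proj_weight of shape (3 * embed_dim, embed_dim), while out_proj.weight is (embed_dim, embed_dim), so out_proj's shape stands in for each individual projection. A quick sanity check (a sketch; embed_dim and num_heads are arbitrary):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=320, num_heads=8)  # self-attention: kdim == vdim == embed_dim
print(mha.in_proj_weight.shape)   # torch.Size([960, 320]) -- q, k, v stacked along dim 0
print(mha.out_proj.weight.shape)  # torch.Size([320, 320]) -- same shape as each q/k/v slice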
@@ -146,7 +152,7 @@ def __init__(self, net: Network, weights: NetworkWeights):
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
 
-        self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
+        self.dora_scale = weights.w.get("dora_scale", None)
         self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
 
     def multiplier(self):
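self.shape feeds directly into the dora_mean_dim expression above, which is evaluated unconditionally in __init__; before this commit, an nn.MultiheadAttention module never got self.shape assigned, so len(self.shape) would presumably fail there. A small sketch of what that expression evaluates to (shapes chosen for illustration, helper name is hypothetical):

def dora_mean_dim(shape):
    # keep every axis except dim 1, mirroring the expression in __init__
    return tuple(i for i in range(len(shape)) if i != 1)

print(dora_mean_dim((320, 320)))        # (0,)       -- Linear / MHA projection weight
print(dora_mean_dim((320, 320, 3, 3)))  # (0, 2, 3)  -- Conv2d weight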
9 changes: 6 additions & 3 deletions extensions-builtin/Lora/networks.py
@@ -429,9 +429,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                 try:
                     with torch.no_grad():
-                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
-                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
-                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+                        # Send "real" orig_weight into MHA's lora module
+                        qw, kw, vw = self.in_proj_weight.chunk(3, 0)
+                        updown_q, _ = module_q.calc_updown(qw)
+                        updown_k, _ = module_k.calc_updown(kw)
+                        updown_v, _ = module_v.calc_updown(vw)
+                        del qw, kw, vw
                         updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                         updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
 
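The core of the fix: calc_updown expects the original weight of a single projection, but the old code handed every q/k/v module the full packed in_proj_weight of shape (3 * embed_dim, embed_dim). A rough sketch of the new chunk-then-vstack round trip, with calc_updown stubbed out as a zero delta of matching shape (stub name and sizes are illustrative, not the extension's API):

import torch

embed_dim = 320
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # packed q, k, v, as in nn.MultiheadAttention

# split into the "real" per-projection original weights, as the patched code does
qw, kw, vw = in_proj_weight.chunk(3, 0)  # each (embed_dim, embed_dim)

def fake_calc_updown(orig_weight):
    # stand-in for module_*.calc_updown(orig_weight): returns (delta, ex_bias)
    return torch.zeros_like(orig_weight), None

updown_q, _ = fake_calc_updown(qw)
updown_k, _ = fake_calc_updown(kw)
updown_v, _ = fake_calc_updown(vw)

# re-pack into one delta that matches in_proj_weight's layout
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
assert updown_qkv.shape == in_proj_weight.shape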
