Merge pull request #15283 from AUTOMATIC1111/dora-weight-decompose
Use correct DoRA implementation
AUTOMATIC1111 authored Mar 16, 2024
2 parents bf35c66 + 8dcb8fa commit df8c09b
Showing 1 changed file with 15 additions and 3 deletions.

extensions-builtin/Lora/network.py (15 additions, 3 deletions)
@@ -153,7 +153,7 @@ def __init__(self, net: Network, weights: NetworkWeights):
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None

         self.dora_scale = weights.w.get("dora_scale", None)
-        self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
+        self.dora_norm_dims = len(self.shape) - 1

     def multiplier(self):
         if 'transformer' in self.sd_key[:20]:
@@ -170,10 +170,22 @@ def calc_scale(self):
         return 1.0

     def apply_weight_decompose(self, updown, orig_weight):
-        orig_weight = orig_weight.to(updown)
+        # Match the device/dtype
+        orig_weight = orig_weight.to(updown.dtype)
+        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
+        updown = updown.to(orig_weight.device)
+
         merged_scale1 = updown + orig_weight
+        merged_scale1_norm = (
+            merged_scale1.transpose(0, 1)
+            .reshape(merged_scale1.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
+            .transpose(0, 1)
+        )
+
         dora_merged = (
-            merged_scale1 / merged_scale1.mean(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale
+            merged_scale1 * (dora_scale / merged_scale1_norm)
         )
         final_updown = dora_merged - orig_weight
         return final_updown
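For reference, here is a minimal standalone sketch of the corrected decomposition, assuming plain PyTorch tensors; the helper signature and the example shapes are illustrative, not part of the commit. DoRA reparameterizes the merged weight as W' = m * (W0 + dW) / ||W0 + dW||_c, where the magnitude vector m (the dora_scale tensor here) rescales each column of the merged weight. The norm is therefore taken per index of dim 1, over every other dimension, which is exactly what the transpose/reshape chain in the diff computes:

import torch

def apply_weight_decompose(updown, orig_weight, dora_scale, dora_norm_dims):
    # Merge the low-rank update into the original weight.
    merged = updown + orig_weight

    # One L2 norm per index of dim 1, computed over all other dimensions,
    # then reshaped so it broadcasts against `merged`.
    norm = (
        merged.transpose(0, 1)
        .reshape(merged.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(merged.shape[1], *[1] * dora_norm_dims)
        .transpose(0, 1)
    )

    # Rescale the merged direction by the magnitude vector, then return
    # only the delta that gets added back onto the original weight.
    return merged * (dora_scale / norm) - orig_weight

# Linear-shaped weight (out_features, in_features): norm broadcasts as (1, in).
w = torch.randn(64, 32)
delta = 0.01 * torch.randn(64, 32)
m = torch.ones(1, 32)
print(apply_weight_decompose(delta, w, m, dora_norm_dims=1).shape)   # torch.Size([64, 32])

# Conv2d-shaped weight (out, in, kh, kw): norm broadcasts as (1, in, 1, 1).
w = torch.randn(64, 32, 3, 3)
delta = 0.01 * torch.randn(64, 32, 3, 3)
m = torch.ones(1, 32, 1, 1)
print(apply_weight_decompose(delta, w, m, dora_norm_dims=3).shape)   # torch.Size([64, 32, 3, 3])

The dora_norm_dims argument plays the same role as self.dora_norm_dims = len(self.shape) - 1 in the patch: it pads the per-column norm back up to the weight's rank so the division broadcasts correctly for both 2D linear and 4D convolution weights.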
