From 1792e193b1ad22727d9628dda9c5c6457fd9f294 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Mar 2024 23:52:29 +0800 Subject: [PATCH] Use correct implementation, fix device error --- extensions-builtin/Lora/network.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 183f8bd7c13..30b979f598d 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -153,7 +153,7 @@ def __init__(self, net: Network, weights: NetworkWeights): self.scale = weights.w["scale"].item() if "scale" in weights.w else None self.dora_scale = weights.w.get("dora_scale", None) - self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1) + self.dora_norm_dims = len(self.shape) - 1 def multiplier(self): if 'transformer' in self.sd_key[:20]: @@ -170,10 +170,22 @@ def calc_scale(self): return 1.0 def apply_weight_decompose(self, updown, orig_weight): - orig_weight = orig_weight.to(updown) + # Match the device/dtype + orig_weight = orig_weight.to(updown.dtype) + dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) + updown = updown.to(orig_weight.device) + merged_scale1 = updown + orig_weight + merged_scale1_norm = ( + merged_scale1.transpose(0, 1) + .reshape(merged_scale1.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + dora_merged = ( - merged_scale1 / merged_scale1(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale + merged_scale1 * (dora_scale / merged_scale1_norm) ) final_updown = dora_merged - orig_weight return final_updown