Skip to content

Commit

Permalink
restore org_dtype != compute dtype case
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Nov 1, 2024
1 parent b783a96 commit 310d0e6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ def restore_weights_backup(obj, field, weight):
setattr(obj, field, None)
return

getattr(obj, field).copy_(weight)
old_weight = getattr(obj, field)
old_weight.copy_(weight.to(dtype=old_weight.dtype))


def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):
Expand Down

0 comments on commit 310d0e6

Please sign in to comment.