Skip to content

Commit

Permalink
Add MPS manual cast
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf committed Oct 28, 2023
1 parent d4d3134 commit ddc2a34
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion modules/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def manual_autocast():
def manual_cast_forward(self, *args, **kwargs):
    """Run ``self.org_forward`` with the module and its tensor inputs cast to ``dtype``.

    ``dtype`` is not a parameter; it is resolved from the enclosing scope
    (defined outside this block). The module is cast to ``dtype`` for the
    duration of the call and then cast back to the dtype of its first
    parameter, even if the wrapped forward raises.

    Args:
        self: The module whose ``org_forward`` attribute holds the forward
            to delegate to.
        *args: Positional inputs; ``torch.Tensor`` values are cast to
            ``dtype``, everything else is passed through untouched.
        **kwargs: Keyword inputs, handled the same way as ``args``.

    Returns:
        Whatever ``self.org_forward`` returns.
    """
    # A module without parameters has no dtype to remember; a bare
    # ``next(self.parameters())`` would raise StopIteration in that case, so
    # fall back to None and skip casting the module itself.
    first_param = next(self.parameters(), None)
    org_dtype = None if first_param is None else first_param.dtype

    if org_dtype is not None:
        self.to(dtype)

    # Only tensors are cast; non-tensor arguments keep their original values.
    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

    try:
        return self.org_forward(*args, **kwargs)
    finally:
        # Restore the original dtype even when the forward raises, so a failed
        # call does not leave the module permanently cast.
        if org_dtype is not None:
            self.to(org_dtype)
Expand All @@ -136,7 +138,6 @@ def manual_cast_forward(self, *args, **kwargs):


def autocast(disable=False):
print(fp8, dtype, shared.cmd_opts.precision, device)
if disable:
return contextlib.nullcontext()

Expand All @@ -146,6 +147,9 @@ def autocast(disable=False):
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_autocast()

if has_mps() and shared.cmd_opts.precision != "full":
return manual_autocast()

if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()

Expand Down

0 comments on commit ddc2a34

Please sign in to comment.