Skip to content

Commit

Permalink
Make fast fp8 take a bit less peak memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Dec 24, 2024
1 parent 73e0498 commit 99a1fb6
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,10 @@ def fp8_linear(self, input):
tensor_2d = True
input = input.unsqueeze(1)


input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()

scale_weight = self.scale_weight
Expand All @@ -269,23 +270,24 @@ def fp8_linear(self, input):

if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)

if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

if isinstance(o, tuple):
o = o[0]

if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape(input_shape[0], -1)

return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return o.reshape((-1, input_shape[1], self.weight.shape[0]))

return None

Expand Down

0 comments on commit 99a1fb6

Please sign in to comment.