RMS Norm doesn't seem to be supported #355
Using this implementation of RMSNorm instead of the built-in one also fails:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, normalized_shape: int, eps=1e-8):
        """
        Root Mean Square Layer Normalization
        :param normalized_shape: input size
        :param eps: epsilon value, default 1e-8
        """
        super().__init__()
        self.eps = eps
        self.normalized_shape = normalized_shape
        self.scale = nn.Parameter(torch.ones(normalized_shape))
        # Redundant: assigning an nn.Parameter attribute already registers it.
        self.register_parameter("scale", self.scale)

    def forward(self, x: torch.Tensor):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        d_x = self.normalized_shape
        rms_x = norm_x * d_x ** (-1.0 / 2)
        x_normed = x / (rms_x + self.eps)
        return self.scale * x_normed

Error:

  File "rmsnormmodel_q.py", line 31, in forward
    norm_0_f = transpose_0_f.norm(2, dim=-1, keepdim=True)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/torch/_tensor.py", line 761, in norm
    return torch.norm(self, p, dim, keepdim, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/torch/functional.py", line 1632, in norm
    return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got QUInt8
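The last line of the traceback explains the failure. In the converted model, the tensor reaching norm is a quantized QUInt8 tensor, and torch.linalg.vector_norm only accepts floating-point or complex input. The following is a minimal sketch, assuming a per-tensor quint8 input with arbitrary scale and zero-point values. It reproduces that class of error and shows one way around it: dequantize first, then do the reduction in fp32. The helper rms_norm_float is hypothetical and is not part of this project or of PyTorch.

```python
import torch

torch.manual_seed(0)

# A float activation and a per-tensor quint8 copy of it, standing in for what a
# quantized graph would feed into RMSNorm (scale/zero_point are illustrative).
x = torch.randn(2, 8)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)

# Calling norm directly on the quantized tensor is expected to raise, because
# the underlying vector norm only accepts floating-point/complex input.
try:
    qx.norm(2, dim=-1, keepdim=True)
    failed = False
except RuntimeError:  # NotImplementedError (older versions) is a subclass
    failed = True


def rms_norm_float(t: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: dequantize if needed, then apply RMSNorm in fp32."""
    if t.is_quantized:
        t = t.dequantize()
    t = t.float()
    rms = torch.sqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)
    return scale * (t / rms)


out = rms_norm_float(qx, torch.ones(8))
print(failed, out.dtype, tuple(out.shape))
```

This trades quantized execution for correctness: the normalization itself runs in floating point, and any requantization of the result is left to the surrounding graph.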
@spacycoder Yes, both quantization for either
FYI:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): input tensor to normalize

        Returns:
            Tensor: The output tensor after applying RMSNorm.
        """
        # computation is in fp32
        x_fp32 = x.float()
        x_normed = (
            x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        ).type_as(x)
        return x_normed * self.scale

A working implementation of RMSNorm can be made like this:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the last dim in fp32, then scale by 1/sqrt(.)
        x_fp32 = x.float()
        var = x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps
        x_norm = x_fp32 * (1. / torch.sqrt(var))
        return self.scale * x_norm
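Mathematically, the reporter's x.norm(2)-based version and the mean-of-squares versions compute the same root-mean-square statistic and differ only in where ε enters. In the following, d is the size of the last dimension (normalized_shape or dim) and g is the learned scale parameter:

$$
\operatorname{rms}(x) = \frac{\lVert x \rVert_2}{\sqrt{d}} = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2}
$$

$$
y_{\text{norm-based}} = g \odot \frac{x}{\operatorname{rms}(x) + \varepsilon},
\qquad
y_{\text{mean-based}} = g \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}}
$$

The two agree as ε goes to 0, so the difference is negligible whenever rms(x) is much larger than ε (the snippets also use different defaults, 1e-8 versus 1e-6).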
@spacycoder OP-wise speaking, yes, we may go through
With #356, at least it won't throw an error for the models you provided. Quantization for those ops is still skipped.
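For setups that still hit the error, one possible user-side approach is to swap every nn.RMSNorm in the model for an explicit-ops module like the working implementation in this thread before running the conversion. The following sketch makes three assumptions. First, it requires a PyTorch release that provides nn.RMSNorm (2.4 and later). Second, it handles only a single trailing normalized dimension. Third, it assumes float32 activations for the default-eps fallback. ExplicitRMSNorm and replace_rmsnorm are hypothetical names and are not part of this project or of PyTorch. This sketch does not establish whether the converter then quantizes or skips these ops.

```python
import torch
import torch.nn as nn


class ExplicitRMSNorm(nn.Module):
    """RMSNorm written with pow/mean/rsqrt, like the workaround in this thread."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        var = x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps
        return self.scale * (x_fp32 * torch.rsqrt(var))


def replace_rmsnorm(module: nn.Module) -> nn.Module:
    """Hypothetical helper: recursively swap nn.RMSNorm children in place."""
    rms_cls = getattr(nn, "RMSNorm", None)  # absent in older PyTorch releases
    if rms_cls is None:
        return module
    for name, child in list(module.named_children()):
        if isinstance(child, rms_cls):
            shape = child.normalized_shape
            # Assumption: normalization over a single trailing dimension only.
            if len(shape) != 1:
                raise ValueError("only a 1-D normalized_shape is handled here")
            # nn.RMSNorm(eps=None) uses the input dtype's machine epsilon;
            # float32 activations are assumed here.
            eps = child.eps if child.eps is not None else torch.finfo(torch.float32).eps
            new = ExplicitRMSNorm(shape[0], eps)
            if child.weight is not None:
                with torch.no_grad():
                    new.scale.copy_(child.weight)
            setattr(module, name, new)
        else:
            replace_rmsnorm(child)
    return module
```

Only child modules are replaced, so a model that is itself a bare nn.RMSNorm would need to be wrapped (for example in nn.Sequential) first. Unlike nn.RMSNorm, the replacement always returns fp32, matching the working implementation above.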
Hi, converting a model that uses nn.RMSNorm does not work. Error: