1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Unify RMSNorm code.

This commit is contained in:
comfyanonymous
2024-08-28 16:18:39 -04:00
parent b79fd7d92c
commit d31e226650
3 changed files with 17 additions and 24 deletions

View File

@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
class SwiGLUFeedForward(nn.Module):