1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

add RMSNorm to comfy.ops

This commit is contained in:
comfyanonymous
2025-04-14 18:00:33 -04:00
parent a14c2fc356
commit 8a438115fb
3 changed files with 88 additions and 17 deletions

View File

@@ -21,6 +21,7 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -146,6 +147,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None