mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Initial support for the stable audio open model.
This commit is contained in:
41
comfy/ops.py
41
comfy/ops.py
@@ -51,6 +51,20 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@@ -133,6 +147,27 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input, output_size=None):
|
||||
num_spatial_dims = 1
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
@@ -147,6 +182,9 @@ class manual_cast(disable_weight_init):
|
||||
class Linear(disable_weight_init.Linear):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Conv1d(disable_weight_init.Conv1d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Conv2d(disable_weight_init.Conv2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
@@ -161,3 +199,6 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||
comfy_cast_weights = True
|
||||
|
Reference in New Issue
Block a user