mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Initial support for the stable audio open model.
This commit is contained in:
@@ -6,12 +6,15 @@ from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import math
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@@ -20,9 +23,10 @@ class ModelType(Enum):
|
||||
STABLE_CASCADE = 4
|
||||
EDM = 5
|
||||
FLOW = 6
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@@ -44,6 +48,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -236,11 +243,11 @@ class BaseModel(torch.nn.Module):
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
|
||||
@@ -590,3 +597,33 @@ class SD3(BaseModel):
|
||||
else:
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * 0.3) * (1024 * 1024)
|
||||
|
||||
|
||||
class StableAudio1(BaseModel):
|
||||
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
|
||||
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
|
||||
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
seconds_start = kwargs.get("seconds_start", 0)
|
||||
seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))
|
||||
|
||||
seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
|
||||
seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
|
||||
|
||||
global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
|
||||
out['global_embed'] = comfy.conds.CONDRegular(global_embed)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
Reference in New Issue
Block a user