mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
AuraFlow model implementation.
This commit is contained in:
@@ -6,6 +6,7 @@ from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.model_management
|
||||
@@ -598,6 +599,17 @@ class SD3(BaseModel):
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * 0.3) * (1024 * 1024)
|
||||
|
||||
class AuraFlow(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class StableAudio1(BaseModel):
|
||||
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
|
||||
|
Reference in New Issue
Block a user