mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Stable Cascade Stage B.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
import comfy.model_management
|
||||
@@ -461,3 +462,27 @@ class StableCascade_C(BaseModel):
|
||||
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class StableCascade_B(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"]
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
||||
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
||||
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['clip'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
Reference in New Issue
Block a user