mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Basic initial support for cosmos predict2 text to image 2B and 14B models. (#8517)
This commit is contained in:
@@ -34,6 +34,7 @@ import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
@@ -48,6 +49,7 @@ import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -63,38 +65,39 @@ class ModelType(Enum):
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
FLOW_COSMOS = 10
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
s = ModelSamplingDiscrete
|
||||
s = comfy.model_sampling.ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
c = comfy.model_sampling.EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
c = comfy.model_sampling.EPS
|
||||
s = comfy.model_sampling.StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.EDM
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousV
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
elif model_type == ModelType.IMG_TO_IMG:
|
||||
c = comfy.model_sampling.IMG_TO_IMG
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -998,6 +1001,22 @@ class CosmosVideo(BaseModel):
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
||||
class CosmosPredict2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
|
||||
self.image_to_video = image_to_video
|
||||
if self.image_to_video:
|
||||
self.concat_keys = ("mask_inverted",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||
return out
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
|
Reference in New Issue
Block a user