1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Compare commits

...

13 Commits

Author SHA1 Message Date
comfyanonymous
619b8cde74 Bump ComfyUI version to 0.3.11 2025-01-16 14:54:48 -05:00
comfyanonymous
31831e6ef1 Code refactor. 2025-01-16 07:23:54 -05:00
comfyanonymous
88ceb28e20 Tweak hunyuan memory usage factor. 2025-01-16 06:31:03 -05:00
comfyanonymous
23289a6a5c Clean up some debug lines. 2025-01-16 04:24:39 -05:00
comfyanonymous
9d8b6c1f46 More accurate memory estimation for cosmos and hunyuan video. 2025-01-16 03:48:40 -05:00
comfyanonymous
6320d05696 Slightly lower hunyuan video memory usage. 2025-01-16 00:23:01 -05:00
comfyanonymous
25683b5b02 Lower cosmos diffusion model memory usage. 2025-01-15 23:46:42 -05:00
comfyanonymous
4758fb64b9 Lower cosmos VAE memory usage by a bit. 2025-01-15 22:57:52 -05:00
comfyanonymous
008761166f Optimize first attention block in cosmos VAE. 2025-01-15 21:48:46 -05:00
comfyanonymous
bfd5dfd611 3.13 doesn't work yet. 2025-01-15 20:32:44 -05:00
comfyanonymous
55ade36d01 Remove python 3.8 from test-build workflow. 2025-01-15 20:24:55 -05:00
comfyanonymous
2e20e399ea Add minimum numpy version to requirements.txt 2025-01-15 20:19:56 -05:00
comfyanonymous
3baf92d120 CosmosImageToVideoLatent batch_size now does something. 2025-01-15 17:19:59 -05:00
16 changed files with 90 additions and 58 deletions

View File

@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt

View File

@@ -168,14 +168,18 @@ class Attention(nn.Module):
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
# apply_rotary_pos_emb inlined
q_shape = q.shape
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
def cal_attn(self, q, k, v, mask=None):
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
# apply_rotary_pos_emb inlined
k_shape = k.shape
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
return q, k, v
def forward(
self,
@@ -191,7 +195,10 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
return self.cal_attn(q, k, v, mask)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
class FeedForward(nn.Module):
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x = x + extra_per_block_pos_emb
for block in self.blocks:
x = block(
x,

View File

@@ -30,6 +30,8 @@ import torch.nn as nn
import torch.nn.functional as F
import logging
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from .patching import (
Patcher,
Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.optimized_attention = vae_attention()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
v, batch_size = time2batch(v)
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.optimized_attention(q, k, v)
h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_)
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
x = self.patcher3d(x)
# downsampling
hs = [self.conv_in(x)]
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
h = self.down[i_level].downsample(h)
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

View File

@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlll
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhl
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhll
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhl
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhh
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xll
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xlh
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhl
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
del xl
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)

View File

@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
self.build_pos_embed(device=device)
self.build_pos_embed(device=device, dtype=dtype)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
def build_pos_embed(self, device=None):
def build_pos_embed(self, device=None, dtype=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
del inputs
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
len_w: int,
len_t: int,
device=None,
dtype=None,
**kwargs,
):
"""
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:

View File

@@ -89,8 +89,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
logging.info(
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logging.debug(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)

View File

@@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)

View File

@@ -5,8 +5,15 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)

View File

@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
return out
def vae_attention():
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
return pytorch_attention
else:
logging.info("Using split attention in VAE")
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention
self.optimized_attention = vae_attention()
def forward(self, x):
h_ = x

View File

@@ -388,8 +388,8 @@ class VAE:
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")

View File

@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
memory_usage_factor = 1.8 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
memory_usage_factor = 2.4 #TODO
memory_usage_factor = 1.6 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

View File

@@ -71,8 +71,8 @@ class CosmosImageToVideoLatent:
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
out_latent["samples"] = latent
out_latent["noise_mask"] = mask
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.10"
__version__ = "0.3.11"

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.10"
version = "0.3.11"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -2,6 +2,7 @@ torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
tokenizers>=0.13.3