diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index b9e47e9f7..a93a13c86 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = { } +def repeat_e(e, x): + repeats = 1 + if e.shape[1] > 1: + repeats = x.shape[1] // e.shape[1] + if repeats == 1: + return e + return torch.repeat_interleave(e, repeats, dim=1) + + class WanAttentionBlock(nn.Module): def __init__(self, @@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module): freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ # assert e.dtype == torch.float32 + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module): # self-attention y = self.self_attn( - self.norm1(x) * (1 + e[1]) + e[0], + self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x), freqs) - x = x + y * e[2] + x = x + y * repeat_e(e[2], x) # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) - y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) - x = x + y * e[5] + y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x)) + x = x + y * repeat_e(e[5], x) return x @@ -331,7 +341,8 @@ class Head(nn.Module): e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) else: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + + x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x))) return x diff --git a/comfy/model_base.py b/comfy/model_base.py index d019b991a..6b7978949 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1202,7 +1202,7 @@ class WAN22(BaseModel): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): if denoise_mask is None: return timestep - temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) + temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) return temp_ts def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):