diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
index 6b07840fc..791596938 100644
--- a/comfy/ldm/wan/vae.py
+++ b/comfy/ldm/wan/vae.py
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
                          self.padding[1], 2 * self.padding[0], 0)
         self.padding = (0, 0, 0)
 
-    def forward(self, x, cache_x=None):
+    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
+        if cache_list is not None:
+            cache_x = cache_list[cache_idx]
+            cache_list[cache_idx] = None
+
         padding = list(self._padding)
         if cache_x is not None and self._padding[4] > 0:
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
             padding[4] -= cache_x.shape[2]
+            del cache_x
         x = F.pad(x, padding)
         return super().forward(x)
 
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
             if in_dim != out_dim else nn.Identity()
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        h = self.shortcut(x)
+        old_x = x
         for layer in self.residual:
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
                             cache_x.device), cache_x
                     ],
                                         dim=2)
-                x = layer(x, feat_cache[idx])
+                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x + h
+        return x + self.shortcut(old_x)
 
 
 class AttentionBlock(nn.Module):
diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py
index b9c2d1a26..1f6d584a2 100644
--- a/comfy/ldm/wan/vae2_2.py
+++ b/comfy/ldm/wan/vae2_2.py
@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
                     ],
                     dim=2,
                 )
-                x = layer(x, feat_cache[idx])
+                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
             else:
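
For context, a minimal standalone sketch of the allocation pattern this patch targets: `CausalConv3d.forward` now receives the cache slot (`cache_list`, `cache_idx`) instead of the cached tensor itself, clears the slot before concatenating, and `del`s the last reference right after `torch.cat`, so the old cache can be freed before `F.pad` and the convolution allocate their outputs. The same idea motivates deferring `self.shortcut(old_x)` until after the residual branch runs. `concat_via_slot` below is a hypothetical helper for illustration only, not code from this repo.

```python
import torch

def concat_via_slot(x, cache_list, cache_idx):
    # Take ownership of the cached tensor and clear its slot so the
    # list no longer keeps it alive across the concatenation.
    cache_x = cache_list[cache_idx]
    cache_list[cache_idx] = None
    if cache_x is not None:
        x = torch.cat([cache_x, x], dim=2)
        # Drop the last reference; the allocator can reuse this memory
        # for whatever the caller allocates next.
        del cache_x
    return x

# Usage: a one-slot cache holding two trailing frames of the previous chunk.
cache = [torch.randn(1, 16, 2, 32, 32)]
out = concat_via_slot(torch.randn(1, 16, 4, 32, 32), cache, 0)
assert cache[0] is None and out.shape[2] == 6
```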