mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Remove unecessary clones in the wan2.2 VAE. (#9083)
This commit is contained in:
@@ -136,7 +136,7 @@ class ResidualBlock(nn.Module):
|
|||||||
if in_dim != out_dim else nn.Identity())
|
if in_dim != out_dim else nn.Identity())
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
h = self.shortcut(x)
|
old_x = x
|
||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
@@ -156,7 +156,7 @@ class ResidualBlock(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x + h
|
return x + self.shortcut(old_x)
|
||||||
|
|
||||||
|
|
||||||
def patchify(x, patch_size):
|
def patchify(x, patch_size):
|
||||||
@@ -327,7 +327,7 @@ class Down_ResidualBlock(nn.Module):
|
|||||||
self.downsamples = nn.Sequential(*downsamples)
|
self.downsamples = nn.Sequential(*downsamples)
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
x_copy = x.clone()
|
x_copy = x
|
||||||
for module in self.downsamples:
|
for module in self.downsamples:
|
||||||
x = module(x, feat_cache, feat_idx)
|
x = module(x, feat_cache, feat_idx)
|
||||||
|
|
||||||
@@ -369,7 +369,7 @@ class Up_ResidualBlock(nn.Module):
|
|||||||
self.upsamples = nn.Sequential(*upsamples)
|
self.upsamples = nn.Sequential(*upsamples)
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||||
x_main = x.clone()
|
x_main = x
|
||||||
for module in self.upsamples:
|
for module in self.upsamples:
|
||||||
x_main = module(x_main, feat_cache, feat_idx)
|
x_main = module(x_main, feat_cache, feat_idx)
|
||||||
if self.avg_shortcut is not None:
|
if self.avg_shortcut is not None:
|
||||||
|
Reference in New Issue
Block a user