mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Tiny wan vae optimizations. (#9136)
This commit is contained in:
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
|
|||||||
self.padding[1], 2 * self.padding[0], 0)
|
self.padding[1], 2 * self.padding[0], 0)
|
||||||
self.padding = (0, 0, 0)
|
self.padding = (0, 0, 0)
|
||||||
|
|
||||||
def forward(self, x, cache_x=None):
|
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||||
|
if cache_list is not None:
|
||||||
|
cache_x = cache_list[cache_idx]
|
||||||
|
cache_list[cache_idx] = None
|
||||||
|
|
||||||
padding = list(self._padding)
|
padding = list(self._padding)
|
||||||
if cache_x is not None and self._padding[4] > 0:
|
if cache_x is not None and self._padding[4] > 0:
|
||||||
cache_x = cache_x.to(x.device)
|
cache_x = cache_x.to(x.device)
|
||||||
x = torch.cat([cache_x, x], dim=2)
|
x = torch.cat([cache_x, x], dim=2)
|
||||||
padding[4] -= cache_x.shape[2]
|
padding[4] -= cache_x.shape[2]
|
||||||
|
del cache_x
|
||||||
x = F.pad(x, padding)
|
x = F.pad(x, padding)
|
||||||
|
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
|
|||||||
if in_dim != out_dim else nn.Identity()
|
if in_dim != out_dim else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
h = self.shortcut(x)
|
old_x = x
|
||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
|
|||||||
cache_x.device), cache_x
|
cache_x.device), cache_x
|
||||||
],
|
],
|
||||||
dim=2)
|
dim=2)
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x + h
|
return x + self.shortcut(old_x)
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
class AttentionBlock(nn.Module):
|
||||||
|
@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
|
|||||||
],
|
],
|
||||||
dim=2,
|
dim=2,
|
||||||
)
|
)
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user