mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
727 lines
22 KiB
Python
727 lines
22 KiB
Python
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
|
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from .vae import AttentionBlock, CausalConv3d, RMS_norm
|
|
|
|
import comfy.ops
|
|
# Layer namespace used below to build the Conv2d layers (``ops.Conv2d``).
# NOTE(review): presumably these layer variants skip default weight
# initialisation, as the name suggests -- confirm in comfy.ops.
ops = comfy.ops.disable_weight_init
|
|
|
|
# Number of trailing frames (along dim 2) sliced off an input with
# ``x[:, :, -CACHE_T:, :, :]`` when a causal-conv cache slot is refreshed.
CACHE_T = 2
|
|
|
|
|
|
class Resample(nn.Module):
    """Resampling layer for 5-D features unpacked as ``(b, c, t, h, w)``.

    ``mode`` selects the behaviour:

    * ``"none"``: ``resample`` is ``nn.Identity``.
    * ``"upsample2d"`` / ``"upsample3d"``: nearest-exact 2x upsampling of the
      last two dims followed by a 3x3 ``Conv2d``, applied frame by frame.
      ``"upsample3d"`` additionally owns ``time_conv`` (``dim -> 2 * dim``
      channels); its doubled channels are interleaved along ``t`` to double
      the frame count.  This only happens when a feature cache is supplied
      and the cache slot has already been visited once.
    * ``"downsample2d"`` / ``"downsample3d"``: zero-pad by one on the right
      and bottom, then a stride-2 3x3 ``Conv2d``, applied frame by frame.
      ``"downsample3d"`` additionally owns a stride-2 ``time_conv`` along
      ``t`` that is only used when a feature cache is supplied.
    """

    def __init__(self, dim, mode):
        assert mode in (
            "none",
            "upsample2d",
            "upsample3d",
            "downsample2d",
            "downsample3d",
        )
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode in ("upsample2d", "upsample3d"):
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                ops.Conv2d(dim, dim, 3, padding=1),
            )
            if mode == "upsample3d":
                self.time_conv = CausalConv3d(
                    dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode in ("downsample2d", "downsample3d"):
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)),
            )
            if mode == "downsample3d":
                self.time_conv = CausalConv3d(
                    dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Resample ``x``.

        Args:
            x: 5-D tensor unpacked as ``(b, c, t, h, w)``.
            feat_cache: optional list of cache slots; the slot at
                ``feat_idx[0]`` is read and overwritten in place.
            feat_idx: single-element list holding the running slot index; it
                is incremented in place for every slot consumed.  A fresh
                ``[0]`` is used when omitted, so no state is shared between
                calls through a mutable default.

        Returns:
            The resampled 5-D tensor.
        """
        if feat_idx is None:
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == "upsample3d" and feat_cache is not None:
            idx = feat_idx[0]
            prev = feat_cache[idx]
            if prev is None:
                # First visit of this slot: mark it with the "Rep" sentinel
                # and skip the temporal convolution for this chunk.
                feat_cache[idx] = "Rep"
                feat_idx[0] += 1
            else:
                # Identify the sentinel by type first so a cached tensor is
                # never compared against a string.
                prev_is_rep = isinstance(prev, str) and prev == "Rep"
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2:
                    if prev_is_rep:
                        # No real previous frame is cached yet: pad with zeros.
                        lead = torch.zeros_like(cache_x)
                    else:
                        # cache last frame of last two chunk
                        lead = prev[:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device)
                    cache_x = torch.cat([lead, cache_x], dim=2)

                if prev_is_rep:
                    x = self.time_conv(x)
                else:
                    # The cached frames go in as CausalConv3d's second
                    # positional argument (defined in .vae, not shown here).
                    x = self.time_conv(x, prev)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1

                # time_conv emits 2 * c channels: split them into two groups
                # and interleave the groups along the time axis.
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0], x[:, 1]), 3)
                x = x.reshape(b, c, t * 2, h, w)

        # Fold time into the batch so the 2-D resampler runs on every frame.
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d" and feat_cache is not None:
            idx = feat_idx[0]
            if feat_cache[idx] is None:
                # First visit: only remember the features, no temporal conv.
                feat_cache[idx] = x.clone()
                feat_idx[0] += 1
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                # Prepend the last cached frame before the strided time conv.
                x = self.time_conv(
                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
        return x
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two norm -> SiLU -> causal-conv stages with a shortcut connection.

    The shortcut is a 1x1x1 ``CausalConv3d`` when ``in_dim != out_dim`` and
    ``nn.Identity`` otherwise.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = (
            CausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim else nn.Identity())

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the residual stack and add the shortcut.

        Args:
            x: input tensor; it is sliced along dim 2 with five indices, so
                it must be 5-D when a cache is used.
            feat_cache: optional list of cache slots, one consumed per
                ``CausalConv3d`` in ``self.residual``; updated in place.
            feat_idx: single-element list holding the running slot index,
                incremented in place.  A fresh ``[0]`` is used when omitted
                so calls never share a mutable default.

        Returns:
            ``residual(x) + shortcut(x)``.
        """
        if feat_idx is None:
            feat_idx = [0]

        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                prev = feat_cache[idx]
                # Keep the trailing frames of this conv's input for the next
                # chunk.
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and prev is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [
                            prev[:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, prev)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
|
|
|
|
|
|
def patchify(x, patch_size):
    """Fold ``patch_size`` x ``patch_size`` spatial patches into channels.

    Args:
        x: 4-D ``(b, c, h, w)`` or 5-D ``(b, c, f, h, w)`` tensor.
        patch_size: side length of the square patch; ``1`` returns ``x``
            unchanged.

    Returns:
        A tensor whose channel dim is multiplied by ``patch_size ** 2`` and
        whose last two dims are divided by ``patch_size``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
    """
    if patch_size == 1:
        return x

    rank = x.dim()
    if rank == 4:
        pattern = "b c (h q) (w r) -> b (c r q) h w"
    elif rank == 5:
        pattern = "b c f (h q) (w r) -> b (c r q) f h w"
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return rearrange(x, pattern, q=patch_size, r=patch_size)
|
|
|
|
|
|
def unpatchify(x, patch_size):
    """Inverse of ``patchify``: unfold channel groups back into patches.

    Args:
        x: 4-D ``(b, c * patch_size**2, h, w)`` or 5-D
            ``(b, c * patch_size**2, f, h, w)`` tensor.
        patch_size: side length of the square patch; ``1`` returns ``x``
            unchanged.

    Returns:
        A tensor whose channel dim is divided by ``patch_size ** 2`` and
        whose last two dims are multiplied by ``patch_size``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
            This mirrors ``patchify``; an unsupported rank is no longer
            returned silently un-rearranged.
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
|
|
|
|
|
|
class AvgDown3D(nn.Module):
    """Parameter-free downsampling: fold time/space blocks into channels,
    then average channel groups down to ``out_channels``.

    ``factor_t`` is the temporal block size and ``factor_s`` the spatial
    block size (used for both height and width).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # Channels after folding must split evenly into out_channels groups.
        folded_channels = in_channels * self.factor
        assert folded_channels % out_channels == 0
        self.group_size = folded_channels // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample a 5-D ``(B, C, T, H, W)`` tensor.

        The time axis is zero-padded at the front so its length becomes a
        multiple of ``factor_t`` before folding.
        """
        ft, fs = self.factor_t, self.factor_s

        front_pad = (ft - x.shape[2] % ft) % ft
        # F.pad lists pads from the last dim backwards: (W, W, H, H, T, T).
        x = F.pad(x, (0, 0, 0, 0, front_pad, 0))

        batch, chans, frames, height, width = x.shape
        t_out, h_out, w_out = frames // ft, height // fs, width // fs

        # Expose each (ft, fs, fs) block, then move the block axes next to
        # the channel axis.
        x = x.view(batch, chans, t_out, ft, h_out, fs, w_out, fs)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()

        # The folded channels (chans * factor) are grouped and averaged.
        x = x.view(batch, self.out_channels, self.group_size,
                   t_out, h_out, w_out)
        return x.mean(dim=2)
|
|
|
|
|
|
class DupUp3D(nn.Module):
    """Parameter-free upsampling: repeat channels, then unfold them into
    time/space blocks of size ``(factor_t, factor_s, factor_s)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # Channels needed for unfolding must be a multiple of in_channels.
        unfolded_channels = out_channels * self.factor
        assert unfolded_channels % in_channels == 0
        self.repeats = unfolded_channels // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        """Upsample a 5-D ``(B, C, T, H, W)`` tensor.

        When ``first_chunk`` is true the first ``factor_t - 1`` output frames
        are dropped.
        """
        ft, fs = self.factor_t, self.factor_s

        x = x.repeat_interleave(self.repeats, dim=1)
        batch, _, frames, height, width = x.shape

        # Split channels into (out, ft, fs, fs) and interleave each factor
        # with the axis it expands.
        x = x.view(batch, self.out_channels, ft, fs, fs,
                   frames, height, width)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(batch, self.out_channels,
                   frames * ft, height * fs, width * fs)

        return x[:, :, ft - 1:, :, :] if first_chunk else x
|
|
|
|
|
|
class Down_ResidualBlock(nn.Module):
    """Encoder stage: ``mult`` residual blocks plus an optional ``Resample``,
    added to a parameter-free ``AvgDown3D`` shortcut of the stage input.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout,
                 mult,
                 temperal_downsample=False,
                 down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            downsamples.append(Resample(out_dim, mode=mode))

        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Run the main path and add the averaged shortcut.

        Args:
            x: stage input.
            feat_cache: optional list of cache slots forwarded to every
                sub-module; updated in place by them.
            feat_idx: single-element running slot index shared with the
                sub-modules.  A fresh ``[0]`` is created when omitted, so no
                mutable default is shared across calls.
        """
        if feat_idx is None:
            feat_idx = [0]

        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)
|
|
|
|
|
|
class Up_ResidualBlock(nn.Module):
    """Decoder stage: ``mult`` residual blocks plus an optional ``Resample``.

    When ``up_flag`` is set, a parameter-free ``DupUp3D`` shortcut of the
    stage input is added to the main path; otherwise there is no shortcut.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout,
                 mult,
                 temperal_upsample=False,
                 up_flag=False):
        super().__init__()
        # Shortcut path with upsample
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                # up_flag is known to be true in this branch.
                factor_s=2,
            )
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            upsamples.append(Resample(out_dim, mode=mode))

        self.upsamples = nn.Sequential(*upsamples)

    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
        """Run the main path and, if present, add the duplicated shortcut.

        Args:
            x: stage input.
            feat_cache: optional list of cache slots forwarded to every
                sub-module; updated in place by them.
            feat_idx: single-element running slot index shared with the
                sub-modules.  A fresh ``[0]`` is created when omitted, so no
                mutable default is shared across calls.
            first_chunk: forwarded to ``DupUp3D``.
        """
        if feat_idx is None:
            feat_idx = [0]

        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx)

        if self.avg_shortcut is None:
            return x_main
        return x_main + self.avg_shortcut(x, first_chunk)
|
|
|
|
|
|
class Encoder3d(nn.Module):
    """Encoder: causal input conv -> down stages -> middle -> output head.

    The input conv takes 12 channels and the head emits ``z_dim`` channels.
    NOTE(review): 12 presumably equals 3 image channels times a 2x2 patch
    (``WanVAE.encode`` calls ``patchify`` with ``patch_size=2``) -- confirm.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        # Copy list arguments so the shared default lists are never stored
        # on (and mutated through) an instance.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        temperal_downsample = list(temperal_downsample)

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]

        # init block
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # downsample blocks
        last_stage = len(dim_mult) - 1
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = (
                temperal_downsample[i]
                if i < len(temperal_downsample) else False)
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                    # Every stage but the last one downsamples spatially.
                    down_flag=i != last_stage,
                ))
        self.downsamples = nn.Sequential(*downsamples)

        # Width of the last stage, stated explicitly rather than read from
        # the leaked loop variable.
        out_dim = dims[-1]

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1),
        )

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` on ``x`` with its cache slot, then refresh the slot."""
        idx = feat_idx[0]
        prev = feat_cache[idx]
        # Trailing frames of the conv input, kept for the next chunk.
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and prev is not None:
            # Too few new frames: borrow the last cached frame as well.
            cache_x = torch.cat(
                [
                    prev[:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x,
                ],
                dim=2,
            )
        x = conv(x, prev)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode ``x``.

        Args:
            x: input tensor for ``conv1`` (12 channels).
            feat_cache: optional list of cache slots, one per causal conv
                visited; read and overwritten in place.
            feat_idx: single-element list with the running slot index,
                incremented in place.  A fresh ``[0]`` is used when omitted,
                so no mutable default is shared across calls.

        Returns:
            The head output with ``z_dim`` channels.
        """
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        # downsamples
        for layer in self.downsamples:
            if use_cache:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # middle
        for layer in self.middle:
            if use_cache and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # head
        for layer in self.head:
            if use_cache and isinstance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)

        return x
|
|
|
|
|
|
class Decoder3d(nn.Module):
    """Decoder: causal input conv -> middle -> up stages -> output head.

    The input conv takes ``z_dim`` channels and the head emits 12 channels.
    NOTE(review): 12 presumably equals 3 image channels times a 2x2 patch
    (``WanVAE.decode`` calls ``unpatchify`` with ``patch_size=2``) -- confirm.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        # Copy list arguments so the shared default lists are never stored
        # on (and mutated through) an instance.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        temperal_upsample = list(temperal_upsample)

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        last_stage = len(dim_mult) - 1
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = (
                temperal_upsample[i]
                if i < len(temperal_upsample) else False)
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                    # Every stage but the last one upsamples spatially.
                    up_flag=i != last_stage,
                ))
        self.upsamples = nn.Sequential(*upsamples)

        # Width of the last stage, stated explicitly rather than read from
        # the leaked loop variable.
        out_dim = dims[-1]

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 12, 3, padding=1),
        )

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` on ``x`` with its cache slot, then refresh the slot."""
        idx = feat_idx[0]
        prev = feat_cache[idx]
        # Trailing frames of the conv input, kept for the next chunk.
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and prev is not None:
            # Too few new frames: borrow the last cached frame as well.
            cache_x = torch.cat(
                [
                    prev[:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x,
                ],
                dim=2,
            )
        x = conv(x, prev)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
        """Decode ``x``.

        Args:
            x: latent tensor for ``conv1`` (``z_dim`` channels).
            feat_cache: optional list of cache slots, one per causal conv
                visited; read and overwritten in place.
            feat_idx: single-element list with the running slot index,
                incremented in place.  A fresh ``[0]`` is used when omitted,
                so no mutable default is shared across calls.
            first_chunk: forwarded to the up stages; only used when
                ``feat_cache`` is given.

        Returns:
            The head output with 12 channels.
        """
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        # middle
        for layer in self.middle:
            if use_cache and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # upsamples
        for layer in self.upsamples:
            if use_cache:
                x = layer(x, feat_cache, feat_idx, first_chunk)
            else:
                x = layer(x)

        # head
        for layer in self.head:
            if use_cache and isinstance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
|
|
|
|
|
def count_conv3d(model):
    """Return how many ``CausalConv3d`` modules ``model.modules()`` yields."""
    return sum(1 for module in model.modules()
               if isinstance(module, CausalConv3d))
|
|
|
|
|
|
class WanVAE(nn.Module):
    """VAE wrapper that encodes and decodes 5-D inputs chunk by chunk.

    Inputs are patchified with ``patch_size=2`` before encoding and decoder
    outputs are unpatchified with the same size.  Per-conv feature caches
    carry state between chunks and are reset around every encode/decode.
    """

    def __init__(
        self,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        # Copy list arguments so the shared default lists are never stored
        # on (and mutated through) an instance.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        temperal_downsample = list(temperal_downsample)

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(
            dim,
            z_dim * 2,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dec_dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_upsample,
            dropout,
        )

    def _encode_moments(self, x):
        """Encode ``x`` and return both halves ``(mu, log_var)`` of ``conv1``.

        The first frame is encoded alone, the remaining frames in groups of
        four; all chunks share ``self._enc_feat_map``.
        """
        self.clear_cache()
        x = patchify(x, patch_size=2)
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        chunks = []
        for i in range(iter_):
            # Every chunk walks the cache slots from the beginning.
            self._enc_conv_idx = [0]
            if i == 0:
                frames = x[:, :, :1, :, :]
            else:
                frames = x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            chunks.append(
                self.encoder(
                    frames,
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                ))
        # Concatenate once along time instead of re-copying per chunk.
        out = torch.cat(chunks, 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        self.clear_cache()
        return mu, log_var

    def encode(self, x):
        """Encode ``x`` and return ``mu`` (the first channel half) only."""
        mu, _ = self._encode_moments(x)
        return mu

    def decode(self, z):
        """Decode latent ``z`` one latent frame at a time.

        Only the first frame is decoded with ``first_chunk=True``.
        """
        self.clear_cache()
        iter_ = z.shape[2]
        x = self.conv2(z)
        chunks = []
        for i in range(iter_):
            # Every chunk walks the cache slots from the beginning.
            self._conv_idx = [0]
            chunks.append(
                self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=(i == 0),
                ))
        # Concatenate once along time instead of re-copying per chunk.
        out = torch.cat(chunks, 2)
        out = unpatchify(out, patch_size=2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        """Encode ``imgs`` and return ``mu`` or a sample drawn around it.

        ``encode`` returns only ``mu``, so the moments are taken from
        ``_encode_moments`` rather than by unpacking ``encode``'s result.
        ``log_var`` is clamped to ``[-30, 20]`` before exponentiation.
        """
        mu, log_var = self._encode_moments(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Reset decoder and encoder caches: one empty slot per CausalConv3d."""
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
|