1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Make VAE code closer to sgm.

This commit is contained in:
comfyanonymous
2023-10-17 14:51:51 -04:00
parent f8caa24bcc
commit d44a2de49f
6 changed files with 224 additions and 211 deletions

View File

@@ -541,7 +541,10 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
conv_out_op=comfy.ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
@@ -570,12 +573,12 @@ class Decoder(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
self.mid.attn_1 = attn_op(block_in)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
@@ -587,13 +590,13 @@ class Decoder(nn.Module):
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(attn_op(block_in))
up = nn.Module()
up.block = block
up.attn = attn
@@ -604,13 +607,13 @@ class Decoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = conv_out_op(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
@@ -621,16 +624,16 @@ class Decoder(nn.Module):
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
@@ -640,7 +643,7 @@ class Decoder(nn.Module):
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h