1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

All the unet ops with weights are now handled by comfy.ops

This commit is contained in:
comfyanonymous
2023-12-04 03:12:18 -05:00
parent 6efe561c2a
commit af365e4dd1
4 changed files with 28 additions and 21 deletions

View File

@@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
@@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
,
)
if self.out_channels == channels:
@@ -810,13 +809,13 @@ class UNetModel(nn.Module):
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)