1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Disable autocast in unet for increased speed.

This commit is contained in:
comfyanonymous
2023-07-05 20:58:44 -04:00
parent 603f02d613
commit ddc6f12ad5
9 changed files with 84 additions and 79 deletions

View File

@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels, dtype=dtype),
nn.GroupNorm(32, channels, dtype=dtype),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
)
@@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, dtype=dtype),
nn.GroupNorm(32, self.out_channels, dtype=dtype),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
@@ -778,13 +778,13 @@ class UNetModel(nn.Module):
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch, dtype=self.dtype),
nn.GroupNorm(32, ch, dtype=self.dtype),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
nn.GroupNorm(32, ch, dtype=self.dtype),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
@@ -821,7 +821,7 @@ class UNetModel(nn.Module):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None: