1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Initialize transformer unet block weights in right dtype at the start.

This commit is contained in:
comfyanonymous
2023-06-15 14:29:26 -04:00
parent 6253ec4aef
commit e21d9ad445
2 changed files with 44 additions and 44 deletions

View File

@@ -631,7 +631,7 @@ class UNetModel(nn.Module):
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -688,7 +688,7 @@ class UNetModel(nn.Module):
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype
),
ResBlock(
ch,
@@ -742,7 +742,7 @@ class UNetModel(nn.Module):
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype
)
)
if level and i == self.num_res_blocks[level]: