mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Initialize transformer unet block weights in right dtype at the start.
This commit is contained in:
@@ -631,7 +631,7 @@ class UNetModel(nn.Module):
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@@ -688,7 +688,7 @@ class UNetModel(nn.Module):
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@@ -742,7 +742,7 @@ class UNetModel(nn.Module):
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
|
Reference in New Issue
Block a user