1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support SSD1B model and make it easier to support asymmetric unets.

This commit is contained in:
comfyanonymous
2023-10-27 14:15:45 -04:00
parent 434ce25ec0
commit 6ec3f12c6e
6 changed files with 153 additions and 96 deletions

View File

@@ -259,10 +259,6 @@ class UNetModel(nn.Module):
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
@@ -289,7 +285,6 @@ class UNetModel(nn.Module):
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@@ -314,6 +309,7 @@ class UNetModel(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
):
@@ -341,10 +337,7 @@ class UNetModel(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@@ -352,18 +345,16 @@ class UNetModel(nn.Module):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]
transformer_depth_output = transformer_depth_output[:]
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
@@ -428,7 +419,8 @@ class UNetModel(nn.Module):
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@@ -444,7 +436,7 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
@@ -488,7 +480,7 @@ class UNetModel(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
mid_block = [
ResBlock(
ch,
time_embed_dim,
@@ -499,8 +491,9 @@ class UNetModel(nn.Module):
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
@@ -515,8 +508,8 @@ class UNetModel(nn.Module):
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
@@ -538,7 +531,8 @@ class UNetModel(nn.Module):
)
]
ch = model_channels * mult
if ds in attention_resolutions:
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@@ -555,7 +549,7 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)