mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Initialize the unet directly on the target device.
This commit is contained in:
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, output_shape=None):
|
||||
assert x.shape[1] == self.channels
|
||||
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
@@ -208,7 +208,8 @@ class ResBlock(TimestepBlock):
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
dtype=None
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@@ -220,19 +221,19 @@ class ResBlock(TimestepBlock):
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, channels, dtype=dtype),
|
||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims, dtype=dtype)
|
||||
self.x_upd = Upsample(channels, False, dims, dtype=dtype)
|
||||
self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
|
||||
self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims, dtype=dtype)
|
||||
self.x_upd = Downsample(channels, False, dims, dtype=dtype)
|
||||
self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
|
||||
self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
@@ -240,15 +241,15 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype),
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -256,10 +257,10 @@ class ResBlock(TimestepBlock):
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
|
||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
@@ -503,6 +504,7 @@ class UNetModel(nn.Module):
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
@@ -564,9 +566,9 @@ class UNetModel(nn.Module):
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim, dtype=self.dtype),
|
||||
linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
@@ -579,9 +581,9 @@ class UNetModel(nn.Module):
|
||||
assert adm_in_channels is not None
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
|
||||
linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -590,7 +592,7 @@ class UNetModel(nn.Module):
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -609,7 +611,8 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
@@ -638,7 +641,7 @@ class UNetModel(nn.Module):
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@@ -657,11 +660,12 @@ class UNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -686,7 +690,8 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@@ -697,7 +702,7 @@ class UNetModel(nn.Module):
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@@ -706,7 +711,8 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
@@ -724,7 +730,8 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
@@ -753,7 +760,7 @@ class UNetModel(nn.Module):
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
@@ -768,24 +775,25 @@ class UNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user