
Basic SD3 controlnet implementation.

Still missing the node to properly use it.
comfyanonymous
2024-06-25 23:40:44 -04:00
parent 66aaa14001
commit f8f7568d03
5 changed files with 165 additions and 15 deletions
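
Note: the hunks below are from the MMDiT model definition and thread a new `control` argument through the forward path. As the changed block loop shows, `control` is expected to be a dict whose "output" entry holds one residual tensor (or None) per joint block; each entry is added to `x` right after the corresponding block runs. A minimal sketch of what a caller might build, with zero tensors as stand-ins and the helper name `make_dummy_control` being illustrative only (a real SD3 controlnet would produce these residuals):

import torch

# Illustrative only: a control dict in the shape forward_core_with_concat now expects.
# A real SD3 controlnet would produce these per-block residuals; zeros are stand-ins.
def make_dummy_control(num_blocks, batch, tokens, hidden_size, device="cpu"):
    residuals = [torch.zeros(batch, tokens, hidden_size, device=device) for _ in range(num_blocks)]
    return {"output": residuals}  # entries may also be None to skip a block

# Hypothetical call against the patched signature (shapes are example values):
# out = model(x, timesteps, context=context, y=y, control=make_dummy_control(24, 1, 4096, 1536))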


@@ -745,6 +745,8 @@ class MMDiT(nn.Module):
         qkv_bias: bool = True,
         context_processor_layers = None,
         context_size = 4096,
+        num_blocks = None,
+        final_layer = True,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -766,7 +768,10 @@ class MMDiT(nn.Module):
         # apply magic --> this defines a head_size of 64
         self.hidden_size = 64 * depth
         num_heads = depth
+        if num_blocks is None:
+            num_blocks = depth
 
+        self.depth = depth
         self.num_heads = num_heads
 
         self.x_embedder = PatchEmbed(
@@ -821,7 +826,7 @@ class MMDiT(nn.Module):
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     attn_mode=attn_mode,
-                    pre_only=i == depth - 1,
+                    pre_only=(i == num_blocks - 1) and final_layer,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
                     swiglu=swiglu,
@@ -830,11 +835,12 @@ class MMDiT(nn.Module):
                     device=device,
                     operations=operations
                 )
-                for i in range(depth)
+                for i in range(num_blocks)
             ]
         )
 
-        self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
         if compile_core:
             assert False
@@ -893,6 +899,7 @@ class MMDiT(nn.Module):
         x: torch.Tensor,
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         if self.register_length > 0:
             context = torch.cat(
@@ -905,13 +912,20 @@ class MMDiT(nn.Module):
         # context is B, L', D
         # x is B, L, D
-        for block in self.joint_blocks:
-            context, x = block(
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
                 context,
                 x,
                 c=c_mod,
                 use_checkpoint=self.use_checkpoint,
             )
+            if control is not None:
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        x += add
 
         x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
         return x
@@ -922,6 +936,7 @@ class MMDiT(nn.Module):
         t: torch.Tensor,
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -943,7 +958,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context)
+        x = self.forward_core_with_concat(x, c, context, control)
 
         x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -956,7 +971,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         timesteps: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
+        control = None,
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y)
+        return super().forward(x, timesteps, context=context, y=y, control=control)
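
Note: the new num_blocks and final_layer constructor arguments let a controlnet variant reuse the MMDiT trunk with a shortened block stack and without the output FinalLayer (the controlnet module itself lives in one of the other changed files, not shown here). A rough sketch of such a truncated instantiation; the import path, channel count, and operations choice are assumptions, and all other constructor arguments are left at their defaults:

import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import MMDiT  # module path assumed

# Sketch only: an MMDiT trunk cut down for controlnet-style use,
# exercising the num_blocks / final_layer kwargs added above.
controlnet_trunk = MMDiT(
    in_channels=16,                            # SD3 latent channels (assumed)
    depth=24,                                  # hidden_size becomes 64 * 24 = 1536
    num_blocks=12,                             # build only the first 12 joint blocks
    final_layer=False,                         # skip FinalLayer; last kept block is not pre_only either
    operations=comfy.ops.disable_weight_init,  # assumed; avoids relying on a None default
)
# With final_layer=False every block still returns (context, x), so each block's x
# can be tapped as a per-block residual for the control dict sketched earlier.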