Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 23:14:49 +08:00)
Basic SD3 controlnet implementation.
Still missing the node to properly use it.
@@ -745,6 +745,8 @@ class MMDiT(nn.Module):
         qkv_bias: bool = True,
         context_processor_layers = None,
         context_size = 4096,
+        num_blocks = None,
+        final_layer = True,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -766,7 +768,10 @@ class MMDiT(nn.Module):
         # apply magic --> this defines a head_size of 64
         self.hidden_size = 64 * depth
         num_heads = depth
+        if num_blocks is None:
+            num_blocks = depth
+
         self.depth = depth
         self.num_heads = num_heads
 
         self.x_embedder = PatchEmbed(
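For reference, the "magic" above ties width to depth: every attention head is 64 wide, so depth alone fixes the model size. A worked example with an assumed depth of 24 (an SD3-medium-style value, not taken from this diff):

    # Worked example of the sizing rule above; depth = 24 is assumed
    # for illustration, not a value from this commit.
    depth = 24
    hidden_size = 64 * depth              # 1536
    num_heads = depth                     # 24 heads
    assert hidden_size // num_heads == 64 # head_size is always 64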
@@ -821,7 +826,7 @@ class MMDiT(nn.Module):
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     attn_mode=attn_mode,
-                    pre_only=i == depth - 1,
+                    pre_only=(i == num_blocks - 1) and final_layer,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
                     swiglu=swiglu,
@@ -830,11 +835,12 @@ class MMDiT(nn.Module):
                     device=device,
                     operations=operations
                 )
-                for i in range(depth)
+                for i in range(num_blocks)
             ]
         )
 
-        self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
         if compile_core:
             assert False
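The two new constructor knobs work together: a controlnet-style model can build only the first num_blocks joint blocks and skip the output head entirely. A minimal sketch of the resulting pre_only behavior, with assumed values:

    # With final_layer = False no block is marked pre_only, so every
    # block keeps producing the full image stream a controlnet can tap.
    num_blocks, final_layer = 12, False   # assumed values
    flags = [(i == num_blocks - 1) and final_layer for i in range(num_blocks)]
    assert not any(flags)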
@@ -893,6 +899,7 @@ class MMDiT(nn.Module):
         x: torch.Tensor,
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         if self.register_length > 0:
             context = torch.cat(
@@ -905,13 +912,20 @@ class MMDiT(nn.Module):
 
         # context is B, L', D
         # x is B, L, D
-        for block in self.joint_blocks:
-            context, x = block(
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
                 context,
                 x,
                 c=c_mod,
                 use_checkpoint=self.use_checkpoint,
             )
+            if control is not None:
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        x += add
 
         x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
         return x
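This loop is the core of the change: after joint block i, a residual taken from control["output"][i] is added onto the image tokens. The diff only consumes that dict; a hedged sketch of the shape it implies (names and sizes are assumptions, since the node that produces it is, per the commit message, still missing):

    import torch

    # Hypothetical control dict in the shape forward_core_with_concat
    # consumes: an "output" list holding one residual (or None) per
    # joint block, each matching the image stream x of shape (B, L, D).
    B, L, D = 1, 4096, 1536   # assumed batch / token count / width
    control = {"output": [torch.zeros(B, L, D), None]}
    # The list may be shorter than the number of blocks; the
    # i < len(control_o) guard in the loop above simply stops applying
    # residuals past its end.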
@@ -922,6 +936,7 @@ class MMDiT(nn.Module):
         t: torch.Tensor,
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -943,7 +958,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context)
+        x = self.forward_core_with_concat(x, c, context, control)
 
         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -956,7 +971,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         timesteps: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
+        control = None,
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y)
+        return super().forward(x, timesteps, context=context, y=y, control=control)
 
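End to end, the new keyword simply threads through the call chain: OpenAISignatureMMDITWrapper.forward -> MMDiT.forward -> forward_core_with_concat, where the residuals land. A hypothetical call site (model construction and the control dict are assumed; the node wiring this up does not exist yet):

    # Hypothetical usage once a controlnet node produces `control`;
    # `model` is an OpenAISignatureMMDITWrapper instance.
    out = model(x, timesteps, context=context, y=y, control=control)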