1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support loading diffusers SD3 model format with UNETLoader node.

This commit is contained in:
comfyanonymous
2024-06-19 21:46:37 -04:00
parent b08a9dd04b
commit 0d6a57938e
4 changed files with 84 additions and 5 deletions

View File

@@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
return diffusers_unet_map
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
MMDIT_MAP_BASIC = {
("context_embedder.bias", "context_embedder.bias"),
("context_embedder.weight", "context_embedder.weight"),
@@ -263,8 +268,8 @@ MMDIT_MAP_BASIC = {
("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
("pos_embed", "pos_embed.pos_embed"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.linear.weight", "proj_out.weight"),
}
@@ -313,8 +318,15 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
for k in MMDIT_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
for k in MMDIT_MAP_BASIC:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
map_basic = MMDIT_MAP_BASIC.copy()
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
for k in map_basic:
if len(k) > 2:
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
else:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map