1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

SD3 Support.

This commit is contained in:
comfyanonymous
2024-06-10 13:26:25 -04:00
parent a5e6a632f9
commit 8c4a9befa7
17 changed files with 132182 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
import comfy.supported_models
import comfy.supported_models_base
import math
import logging
def count_blocks(state_dict_keys, prefix_string):
@@ -26,12 +27,47 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None
def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
unet_config = {}
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
unet_config["patch_size"] = patch_size
unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
unet_config["input_size"] = None
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
if y_key in state_dict_keys:
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
context_key = '{}context_embedder.weight'.format(key_prefix)
if context_key in state_dict_keys:
in_features = state_dict[context_key].shape[1]
out_features = state_dict[context_key].shape[0]
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
num_patches_key = '{}pos_embed'.format(key_prefix)
if num_patches_key in state_dict_keys:
num_patches = state_dict[num_patches_key].shape[1]
unet_config["num_patches"] = num_patches
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
if rms_qk in state_dict_keys:
unet_config["qk_norm"] = "rms"
unet_config["pos_embed_scaling_factor"] = None #unused for inference
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
if context_processor in state_dict_keys:
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
return unet_config
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
unet_config = {}
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
@@ -58,7 +94,6 @@ def detect_unet_config(state_dict, key_prefix):
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
unet_config = {
@@ -93,6 +128,7 @@ def detect_unet_config(state_dict, key_prefix):
use_linear_in_transformer = False
video_model = False
video_model_cross = False
current_res = 1
count = 0
@@ -136,6 +172,7 @@ def detect_unet_config(state_dict, key_prefix):
context_dim = out[1]
use_linear_in_transformer = out[2]
video_model = out[3]
video_model_cross = out[4]
else:
transformer_depth.append(0)
@@ -176,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
unet_config["disable_temporal_crossattention"] = not video_model_cross
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False