1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Stable Cascade Stage C.

This commit is contained in:
comfyanonymous
2024-02-16 10:55:08 -05:00
parent 5e06baf112
commit f83109f09b
11 changed files with 619 additions and 31 deletions

View File

@@ -28,9 +28,26 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
unet_config = {}
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
w = state_dict[text_mapper_name]
if w.shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
elif w.shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
return unet_config
unet_config = {
"use_checkpoint": False,
"image_size": 32,
@@ -45,7 +62,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
else:
unet_config["adm_in_channels"] = None
unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -159,8 +175,8 @@ def model_config_from_unet_config(unet_config):
print("no match", unet_config)
return None
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
@@ -206,7 +222,7 @@ def convert_config(unet_config):
return new_config
def unet_config_from_diffusers_unet(state_dict, dtype):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {}
transformer_depth = []
@@ -313,8 +329,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
return convert_config(unet_config)
return None
def model_config_from_diffusers_unet(state_dict, dtype):
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
def model_config_from_diffusers_unet(state_dict):
    """Look up a model config for a diffusers-format UNet state dict.

    The state dict is first passed to ``unet_config_from_diffusers_unet``;
    if that yields a config, it is handed to
    ``model_config_from_unet_config`` and that result is returned.

    Returns None when no UNet config could be derived from ``state_dict``.
    """
    detected_config = unet_config_from_diffusers_unet(state_dict)
    # Guard clause: nothing recognised, so there is nothing to match against.
    if detected_config is None:
        return None
    return model_config_from_unet_config(detected_config)