mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Stable Cascade Stage C.
This commit is contained in:
@@ -28,9 +28,26 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
def detect_unet_config(state_dict, key_prefix):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
unet_config = {}
|
||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||
if text_mapper_name in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'c'
|
||||
w = state_dict[text_mapper_name]
|
||||
if w.shape[0] == 1536: #stage c lite
|
||||
unet_config['c_cond'] = 1536
|
||||
unet_config['c_hidden'] = [1536, 1536]
|
||||
unet_config['nhead'] = [24, 24]
|
||||
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||
elif w.shape[0] == 2048: #stage c full
|
||||
unet_config['c_cond'] = 2048
|
||||
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'b'
|
||||
return unet_config
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
@@ -45,7 +62,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@@ -159,8 +175,8 @@ def model_config_from_unet_config(unet_config):
|
||||
print("no match", unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
@@ -206,7 +222,7 @@ def convert_config(unet_config):
|
||||
return new_config
|
||||
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
@@ -313,8 +329,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
def model_config_from_diffusers_unet(state_dict):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
Reference in New Issue
Block a user