mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Support controlnet in diffusers format.
This commit is contained in:
@@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
return model_config_from_unet_config(unet_config)
|
||||
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
match = {}
|
||||
match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
|
||||
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||
match["adm_in_channels"] = None
|
||||
if "class_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
||||
elif "add_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
for k in match:
|
||||
if match[k] != unet_config[k]:
|
||||
matches = False
|
||||
break
|
||||
if matches:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
Reference in New Issue
Block a user