1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

AuraFlow model implementation.

This commit is contained in:
comfyanonymous
2024-07-11 16:51:06 -04:00
parent f45157e3ac
commit 9f291d75b3
12 changed files with 1744 additions and 2 deletions

View File

@@ -105,6 +105,12 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["audio_model"] = "dit1.0"
return unet_config
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -253,6 +259,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models
unet_key_prefix = "model.model."
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
unet_key_prefix = "model."
else:
unet_key_prefix = "model.diffusion_model."
return unet_key_prefix