mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
AuraFlow model implementation.
This commit is contained in:
@@ -105,6 +105,12 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
return unet_config
|
||||
|
||||
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
|
||||
unet_config = {}
|
||||
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
|
||||
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
|
||||
return unet_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -253,6 +259,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
||||
unet_key_prefix = "model.model."
|
||||
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
|
||||
unet_key_prefix = "model."
|
||||
else:
|
||||
unet_key_prefix = "model.diffusion_model."
|
||||
return unet_key_prefix
|
||||
|
Reference in New Issue
Block a user