mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Initial support for the stable audio open model.
This commit is contained in:
@@ -96,6 +96,11 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
return unet_config
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
@@ -236,6 +241,13 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
else:
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
||||
unet_key_prefix = "model.model."
|
||||
else:
|
||||
unet_key_prefix = "model.diffusion_model."
|
||||
return unet_key_prefix
|
||||
|
||||
def convert_config(unet_config):
|
||||
new_config = unet_config.copy()
|
||||
num_res_blocks = new_config.get("num_res_blocks", None)
|
||||
|
Reference in New Issue
Block a user