1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Support AuraFlow Lora and loading model weights in diffusers format.

You can load model weights in diffusers format using the UNETLoader node.
This commit is contained in:
comfyanonymous
2024-07-13 13:51:40 -04:00
parent ce2473bb01
commit a3dffc447a
4 changed files with 133 additions and 44 deletions

View File

@@ -109,6 +109,10 @@ def detect_unet_config(state_dict, key_prefix):
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
unet_config["n_layers"] = double_layers + single_layers
return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
@@ -450,37 +454,45 @@ def model_config_from_diffusers_unet(state_dict):
return None
def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
if num_blocks > 0:
out_sd = {}
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
out_sd = {}
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
for k in sd_map:
weight = state_dict.get(k, None)
if weight is not None:
t = sd_map[k]
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else:
return None
if not isinstance(t, str):
if len(t) > 2:
fun = t[2]
else:
fun = lambda a: a
offset = t[1]
if offset is not None:
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
for k in sd_map:
weight = state_dict.get(k, None)
if weight is not None:
t = sd_map[k]
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
old_weight = weight
w = weight
w[:] = fun(weight)
t = t[0]
out_sd[t] = old_weight
if not isinstance(t, str):
if len(t) > 2:
fun = t[2]
else:
out_sd[t] = weight
state_dict.pop(k)
fun = lambda a: a
offset = t[1]
if offset is not None:
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
old_weight = weight
w = weight
w[:] = fun(weight)
t = t[0]
out_sd[t] = old_weight
else:
out_sd[t] = weight
state_dict.pop(k)
return out_sd