1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Fix issues with #4302 and support loading diffusers format flux.

This commit is contained in:
comfyanonymous
2024-08-10 21:28:24 -04:00
parent 1765f1c60c
commit 75b9b55b22
2 changed files with 20 additions and 8 deletions

View File

@@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):
def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {}
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
if old_weight.shape[offset[0]] < offset[1] + offset[2]:
exp = list(weight.shape)
exp[offset[0]] = offset[1] + offset[2]
new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
new[:old_weight.shape[0]] = old_weight
old_weight = new
w = old_weight.narrow(offset[0], offset[1], offset[2])
else: