1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Support loading diffusers SD3 model format with UNETLoader node.

This commit is contained in:
comfyanonymous
2024-06-19 21:46:37 -04:00
parent b08a9dd04b
commit 0d6a57938e
4 changed files with 84 additions and 5 deletions

View File

@@ -1,7 +1,9 @@
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
import math
import logging
import torch
def count_blocks(state_dict_keys, prefix_string):
count = 0
@@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None
def convert_diffusers_mmdit(state_dict, output_prefix=""):
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
if depth > 0:
out_sd = {}
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
for k in sd_map:
weight = state_dict.get(k, None)
if weight is not None:
t = sd_map[k]
if not isinstance(t, str):
if len(t) > 2:
fun = t[2]
else:
fun = lambda a: a
offset = t[1]
if offset is not None:
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
old_weight = weight
w = weight
w[:] = fun(weight)
t = t[0]
out_sd[t] = old_weight
else:
out_sd[t] = weight
state_dict.pop(k)
return out_sd