mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Support loading diffusers SD3 model format with UNETLoader node.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
@@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
if depth > 0:
|
||||
out_sd = {}
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
t = sd_map[k]
|
||||
|
||||
if not isinstance(t, str):
|
||||
if len(t) > 2:
|
||||
fun = t[2]
|
||||
else:
|
||||
fun = lambda a: a
|
||||
offset = t[1]
|
||||
if offset is not None:
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
t = t[0]
|
||||
out_sd[t] = old_weight
|
||||
else:
|
||||
out_sd[t] = weight
|
||||
state_dict.pop(k)
|
||||
|
||||
return out_sd
|
||||
|
Reference in New Issue
Block a user