1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support SSD1B model and make it easier to support asymmetric unets.

This commit is contained in:
comfyanonymous
2023-10-27 14:15:45 -04:00
parent 434ce25ec0
commit 6ec3f12c6e
6 changed files with 153 additions and 96 deletions

View File

@@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
return last_transformer_depth, context_dim, use_linear_in_transformer
return None
def detect_unet_config(state_dict, key_prefix, dtype):
state_dict_keys = list(state_dict.keys())
@@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
channel_mult = []
attention_resolutions = []
transformer_depth = []
transformer_depth_output = []
context_dim = None
use_linear_in_transformer = False
@@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype):
count = 0
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
while True:
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
for count in range(input_block_count):
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
if len(block_keys) == 0:
break
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
if "{}0.op.weight".format(prefix) in block_keys: #new layer
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
current_res *= 2
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
if out is not None:
transformer_depth_output.append(out[0])
else:
transformer_depth_output.append(0)
else:
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
if context_dim is None:
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
transformer_depth.append(out[0])
if context_dim is None:
context_dim = out[1]
use_linear_in_transformer = out[2]
else:
transformer_depth.append(0)
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
if res_block_prefix in block_keys_output:
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
if out is not None:
transformer_depth_output.append(out[0])
else:
transformer_depth_output.append(0)
count += 1
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
if len(set(num_res_blocks)) == 1:
num_res_blocks = num_res_blocks[0]
if len(set(transformer_depth)) == 1:
transformer_depth = transformer_depth[0]
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
else:
transformer_depth_middle = -1
unet_config["in_channels"] = in_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["attention_resolutions"] = attention_resolutions
unet_config["transformer_depth"] = transformer_depth
unet_config["transformer_depth_output"] = transformer_depth_output
unet_config["channel_mult"] = channel_mult
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
@@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
else:
return model_config
def convert_config(unet_config):
new_config = unet_config.copy()
num_res_blocks = new_config.get("num_res_blocks", None)
channel_mult = new_config.get("channel_mult", None)
if isinstance(num_res_blocks, int):
num_res_blocks = len(channel_mult) * [num_res_blocks]
if "attention_resolutions" in new_config:
attention_resolutions = new_config.pop("attention_resolutions")
transformer_depth = new_config.get("transformer_depth", None)
transformer_depth_middle = new_config.get("transformer_depth_middle", None)
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
t_in = []
t_out = []
s = 1
for i in range(len(num_res_blocks)):
res = num_res_blocks[i]
d = 0
if s in attention_resolutions:
d = transformer_depth[i]
t_in += [d] * res
t_out += [d] * (res + 1)
s *= 2
transformer_depth = t_in
transformer_depth_output = t_out
new_config["transformer_depth"] = t_in
new_config["transformer_depth_output"] = t_out
new_config["transformer_depth_middle"] = transformer_depth_middle
new_config["num_res_blocks"] = num_res_blocks
return new_config
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
attention_resolutions = []
@@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
matches = False
break
if matches:
return unet_config
return convert_config(unet_config)
return None
def model_config_from_diffusers_unet(state_dict, dtype):