mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 07:26:31 +08:00
Support SSD1B model and make it easier to support asymmetric unets.
This commit is contained in:
@@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
attention_resolutions = unet_config["attention_resolutions"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
||||
num_blocks = len(channel_mult)
|
||||
if isinstance(num_res_blocks, int):
|
||||
num_res_blocks = [num_res_blocks] * num_blocks
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = [transformer_depth] * num_blocks
|
||||
|
||||
transformers_per_layer = []
|
||||
res = 1
|
||||
for i in range(num_blocks):
|
||||
transformers = 0
|
||||
if res in attention_resolutions:
|
||||
transformers = transformer_depth[i]
|
||||
transformers_per_layer.append(transformers)
|
||||
res *= 2
|
||||
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
||||
|
||||
diffusers_unet_map = {}
|
||||
for x in range(num_blocks):
|
||||
@@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
|
||||
for i in range(num_res_blocks[x]):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
||||
if transformers_per_layer[x] > 0:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(transformers_per_layer[x]):
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
n += 1
|
||||
@@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
|
||||
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
||||
|
||||
num_res_blocks = list(reversed(num_res_blocks))
|
||||
transformers_per_layer = list(reversed(transformers_per_layer))
|
||||
for x in range(num_blocks):
|
||||
n = (num_res_blocks[x] + 1) * x
|
||||
l = num_res_blocks[x] + 1
|
||||
@@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
||||
c += 1
|
||||
if transformers_per_layer[x] > 0:
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
c += 1
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(transformers_per_layer[x]):
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
if i == l - 1:
|
||||
|
Reference in New Issue
Block a user