mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Properly support SDXL diffusers loras for unet.

comfyanonymous
2023-07-04 21:10:12 -04:00
parent 8d694cc450
commit acf95191ff
2 changed files with 128 additions and 116 deletions
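
The change below drops the hand-maintained per-block LoRA key tables for the unet and instead derives keys directly from the model's state dict, translating diffusers-style names through utils.unet_to_diffusers. A minimal sketch of the naming convention the new model_lora_keys_unet relies on (lora_key_for and demo_key are illustrative names, not code from this repository):

# Sketch only: shows how a unet state-dict key is flattened into a
# kohya-style "lora_unet_*" name; the helper names here are made up.
def lora_key_for(model_key: str) -> str:
    prefix = "diffusion_model."
    suffix = ".weight"
    flattened = model_key[len(prefix):-len(suffix)].replace(".", "_")
    return "lora_unet_" + flattened

demo_key = "diffusion_model.input_blocks.1.1.proj_in.weight"
print(lora_key_for(demo_key))  # lora_unet_input_blocks_1_1_proj_in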


@@ -59,35 +59,6 @@ LORA_CLIP_MAP = {
"self_attn.out_proj": "self_attn_out_proj",
}
LORA_UNET_MAP_ATTENTIONS = {
"proj_in": "proj_in",
"proj_out": "proj_out",
}
transformer_lora_blocks = {
"transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
"transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
"transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
"transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
"transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
"transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
"transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
"transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
"transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
"transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
}
for i in range(10):
for k in transformer_lora_blocks:
LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
LORA_UNET_MAP_RESNET = {
"in_layers.2": "resnets_{}_conv1",
"emb_layers.1": "resnets_{}_time_emb_proj",
"out_layers.3": "resnets_{}_conv2",
"skip_connection": "resnets_{}_conv_shortcut"
}
def load_lora(lora, to_load):
patch_dict = {}
@@ -188,39 +159,9 @@ def load_lora(lora, to_load):
print("lora key not loaded", x)
return patch_dict
def model_lora_keys(model, key_map={}):
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
counter = 0
for b in range(12):
tk = "diffusion_model.input_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
if k in sdk:
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k
up_counter += 1
if up_counter >= 4:
counter += 1
for c in LORA_UNET_MAP_ATTENTIONS:
k = "diffusion_model.middle_block.1.{}.weight".format(c)
if k in sdk:
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k
counter = 3
for b in range(12):
tk = "diffusion_model.output_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
if k in sdk:
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k
up_counter += 1
if up_counter >= 4:
counter += 1
counter = 0
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32):
@@ -244,69 +185,23 @@ def model_lora_keys(model, key_map={}):
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
key_map[lora_key] = k
return key_map
    #Locon stuff
    ds_counter = 0
    counter = 0
    for b in range(12):
        tk = "diffusion_model.input_blocks.{}.0".format(b)
        key_in = False
        for c in LORA_UNET_MAP_RESNET:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
                key_map[lora_key] = k
                key_in = True
        for bb in range(3):
            k = "{}.{}.op.weight".format(tk[:-2], bb)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
                key_map[lora_key] = k
                ds_counter += 1
        if key_in:
            counter += 1
    counter = 0
    for b in range(3):
        tk = "diffusion_model.middle_block.{}".format(b)
        key_in = False
        for c in LORA_UNET_MAP_RESNET:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
                key_map[lora_key] = k
                key_in = True
        if key_in:
            counter += 1
    counter = 0
    us_counter = 0
    for b in range(12):
        tk = "diffusion_model.output_blocks.{}.0".format(b)
        key_in = False
        for c in LORA_UNET_MAP_RESNET:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
                key_map[lora_key] = k
                key_in = True
        for bb in range(3):
            k = "{}.{}.conv.weight".format(tk[:-2], bb)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
                key_map[lora_key] = k
                us_counter += 1
        if key_in:
            counter += 1
def model_lora_keys_unet(model, key_map={}):
    sdk = model.state_dict().keys()
    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k

    diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
    return key_map

class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0):
        self.size = size
@@ -506,8 +401,8 @@ class ModelPatcher:
        self.backup = {}

def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    key_map = model_lora_keys(model.model)
    key_map = model_lora_keys(clip.cond_stage_model, key_map)
    key_map = model_lora_keys_unet(model.model)
    key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
    loaded = load_lora(lora, key_map)
    new_modelpatcher = model.clone()
    k = new_modelpatcher.add_patches(loaded, strength_model)
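
For context, a minimal usage sketch of the patched entry point; apply_lora and lora_path are illustrative names, and the comfy.sd / comfy.utils layout is assumed to match the repository at the time of this commit:

import comfy.sd
import comfy.utils

def apply_lora(model, clip, lora_path, strength_model=1.0, strength_clip=1.0):
    # Load the LoRA state dict from disk, then return patched clones of model and clip.
    lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
    return comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)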