1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 07:26:31 +08:00

Simpler base model code.

This commit is contained in:
comfyanonymous
2023-06-09 12:24:24 -04:00
parent 4b0b516544
commit de142eaad5
4 changed files with 163 additions and 74 deletions

View File

@@ -15,8 +15,15 @@ from . import utils
from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
replace_prefix = {"model.diffusion_model.": "diffusion_model."}
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
for x in replace:
sd[x[1]] = sd.pop(x[0])
m, u = model.load_state_dict(sd, strict=False)
k = list(sd.keys())
@@ -182,7 +189,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
for b in range(12):
tk = "model.diffusion_model.input_blocks.{}.1".format(b)
tk = "diffusion_model.input_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
@@ -193,13 +200,13 @@ def model_lora_keys(model, key_map={}):
if up_counter >= 4:
counter += 1
for c in LORA_UNET_MAP_ATTENTIONS:
k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
k = "diffusion_model.middle_block.1.{}.weight".format(c)
if k in sdk:
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k
counter = 3
for b in range(12):
tk = "model.diffusion_model.output_blocks.{}.1".format(b)
tk = "diffusion_model.output_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
@@ -223,7 +230,7 @@ def model_lora_keys(model, key_map={}):
ds_counter = 0
counter = 0
for b in range(12):
tk = "model.diffusion_model.input_blocks.{}.0".format(b)
tk = "diffusion_model.input_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@@ -242,7 +249,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
for b in range(3):
tk = "model.diffusion_model.middle_block.{}".format(b)
tk = "diffusion_model.middle_block.{}".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@@ -256,7 +263,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
us_counter = 0
for b in range(12):
tk = "model.diffusion_model.output_blocks.{}.0".format(b)
tk = "diffusion_model.output_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@@ -332,7 +339,7 @@ class ModelPatcher:
patch_list[i] = patch_list[i].to(device)
def model_dtype(self):
return self.model.diffusion_model.dtype
return self.model.get_dtype()
def add_patches(self, patches, strength=1.0):
p = {}
@@ -764,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
sd_key = "diffusion_model.{}".format(x[len(c_m):])
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
@@ -931,9 +938,10 @@ def load_gligen(ckpt_path):
model = model.half()
return model
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
if config is None:
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
@@ -942,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
if "use_fp16" in model_config_params["unet_config"]["params"]:
fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"]
noise_aug_config = None
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
v_prediction = False
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
v_prediction = True
clip = None
vae = None
@@ -963,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
sd = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
@@ -1073,16 +1099,20 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
unclip_model = False
inpaint_model = False
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"
@@ -1096,13 +1126,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = sd[unclip].shape[1]
v_prediction = False
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = sd[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
v_prediction = True
sd_config["parameterization"] = 'v'
model = instantiate_from_config(model_config)
if inpaint_model:
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif unclip_model:
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16: