1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Add support for textual inversion embedding for SD1.x CLIP.

This commit is contained in:
comfyanonymous
2023-01-29 18:46:44 -05:00
parent 702ac43d0c
commit f73e57d881
6 changed files with 108 additions and 15 deletions

View File

@@ -53,19 +53,25 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
class CLIP:
def __init__(self, config):
def __init__(self, config, embedding_directory=None):
self.target_clip = config["target"]
if "params" in config:
params = config["params"]
else:
params = {}
tokenizer_params = {}
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
clip = sd2_clip.SD2ClipModel
tokenizer = sd2_clip.SD2Tokenizer
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
clip = sd1_clip.SD1ClipModel
tokenizer = sd1_clip.SD1Tokenizer
if "params" in config:
self.cond_stage_model = clip(**(config["params"]))
else:
self.cond_stage_model = clip()
self.tokenizer = tokenizer()
tokenizer_params['embedding_directory'] = embedding_directory
self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(**(tokenizer_params))
def encode(self, text):
tokens = self.tokenizer.tokenize_with_weights(text)
@@ -103,7 +109,7 @@ class VAE:
return samples
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
config = OmegaConf.load(config_path)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
@@ -124,7 +130,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config)
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]