1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Add support for textual inversion embedding for SD1.x CLIP.

This commit is contained in:
comfyanonymous
2023-01-29 18:46:44 -05:00
parent 702ac43d0c
commit f73e57d881
6 changed files with 108 additions and 15 deletions

View File

@@ -63,9 +63,38 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = "hidden"
self.layer_idx = layer_idx
def set_up_textual_embeddings(self, tokens, current_embeds):
    """Replace embedding vectors inside ``tokens`` with newly allocated token ids.

    ``tokens`` is a list of token rows. Each entry of a row is either a plain
    ``int`` token id (kept as-is) or an embedding vector (anything that is not
    an ``int``). Every vector is assigned the next unused id past the end of
    ``current_embeds`` and the vector itself becomes the matching row of an
    enlarged embedding table, which is installed on ``self.transformer`` via
    ``set_input_embeddings``.

    Args:
        tokens: list of rows mixing ``int`` ids and embedding vectors.
        current_embeds: the embedding module currently used by the
            transformer; only its ``weight`` is read.

    Returns:
        A list of rows with the same layout as ``tokens`` that contains only
        integer ids. When no vectors are present the transformer is left
        untouched.
    """
    token_dict_size = current_embeds.weight.shape[0]
    next_new_token = token_dict_size
    embedding_weights = []
    out_tokens = []
    for row in tokens:
        row_ids = []
        for item in row:
            if isinstance(item, int):
                row_ids.append(item)
            else:
                # Not an id: remember the vector and hand out a fresh id for it.
                embedding_weights.append(item)
                row_ids.append(next_new_token)
                next_new_token += 1
        out_tokens.append(row_ids)

    if embedding_weights:
        base_weight = current_embeds.weight
        # Allocate the enlarged table on the same device/dtype as the existing
        # one so it can be used by the transformer without a mismatch.
        new_embedding = torch.nn.Embedding(
            next_new_token,
            base_weight.shape[1],
            device=base_weight.device,
            dtype=base_weight.dtype,
        )
        # Writing into slices of a parameter that requires grad is an in-place
        # op autograd rejects; do the copies with grad tracking disabled so
        # this works regardless of the caller's grad mode.
        with torch.no_grad():
            new_embedding.weight[:token_dict_size] = base_weight
            for offset, vector in enumerate(embedding_weights):
                new_embedding.weight[token_dict_size + offset] = vector
        self.transformer.set_input_embeddings(new_embedding)
    return out_tokens
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(self.device)
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs.last_hidden_state
@@ -138,18 +167,49 @@ def unescape_important(text):
text = text.replace("\0\2", "(")
return text
def load_embed(embedding_name, embedding_directory):
    """Load a textual-inversion embedding tensor from ``embedding_directory``.

    ``embedding_name`` is joined onto ``embedding_directory``. If that exact
    path is not a file, the extensions ``.safetensors``, ``.pt`` and ``.bin``
    are tried in that order. ``.safetensors`` files are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load`` with ``weights_only=True``.

    Args:
        embedding_name: file name (with or without extension) of the
            embedding, relative to ``embedding_directory``.
        embedding_directory: directory that holds the embedding files.

    Returns:
        The first value stored in the file (or in its ``'string_to_param'``
        mapping when that key is present), or ``None`` when the embedding is
        missing, empty, or resolves to a path outside
        ``embedding_directory``.
    """
    embed_dir = os.path.abspath(embedding_directory)
    embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))

    # The name comes from prompt text, so refuse anything ("../x", absolute
    # paths) that would escape the embedding directory.
    try:
        inside = os.path.commonpath((embed_dir, embed_path)) == embed_dir
    except ValueError:
        # Raised e.g. when the two paths live on different drives.
        inside = False
    if not inside:
        print("warning, embedding {} is outside the embedding directory, ignoring".format(embed_path))
        return None

    if not os.path.isfile(embed_path):
        candidates = (embed_path + ext for ext in ('.safetensors', '.pt', '.bin'))
        valid_file = next((c for c in candidates if os.path.isfile(c)), None)
        if valid_file is None:
            print("warning, embedding {} does not exist, ignoring".format(embed_path))
            return None
        embed_path = valid_file

    if embed_path.lower().endswith(".safetensors"):
        import safetensors.torch
        embed = safetensors.torch.load_file(embed_path, device="cpu")
    else:
        # weights_only=True keeps torch.load from unpickling arbitrary objects.
        embed = torch.load(embed_path, weights_only=True, map_location="cpu")

    if 'string_to_param' in embed:
        values = embed['string_to_param'].values()
    else:
        values = embed.values()

    first = next(iter(values), None)
    if first is None:
        print("warning, embedding {} contains no tensors, ignoring".format(embed_path))
    return first
class SD1Tokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
    """Load the CLIP tokenizer and cache the values tokenization needs.

    Args:
        tokenizer_path: location passed to ``CLIPTokenizer.from_pretrained``;
            defaults to the ``sd1_tokenizer`` folder next to this module.
        max_length: total tokens per section, start and end tokens included.
        pad_with_end: stored as ``self.pad_with_end``.
        embedding_directory: stored as ``self.embedding_directory``; may be
            ``None``.
    """
    if tokenizer_path is None:
        module_dir = os.path.dirname(os.path.realpath(__file__))
        tokenizer_path = os.path.join(module_dir, "sd1_tokenizer")
    self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    self.max_length = max_length
    # Two slots of every section are reserved for the start and end tokens.
    self.max_tokens_per_section = self.max_length - 2

    # Relies on the tokenizer encoding an empty string as [start, end].
    blank_ids = self.tokenizer('')["input_ids"]
    self.start_token = blank_ids[0]
    self.end_token = blank_ids[1]
    self.pad_with_end = pad_with_end

    # Reverse lookup: token id -> token string.
    vocabulary = self.tokenizer.get_vocab()
    self.inv_vocab = dict(zip(vocabulary.values(), vocabulary.keys()))

    self.embedding_directory = embedding_directory
    self.max_word_length = 8
def tokenize_with_weights(self, text):
text = escape_important(text)
@@ -157,13 +217,34 @@ class SD1Tokenizer:
tokens = []
for t in parsed_weights:
tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
for x in tt:
tokens += [(x, t[1])]
to_tokenize = unescape_important(t[0]).split(' ')
for word in to_tokenize:
temp_tokens = []
embedding_identifier = "embedding:"
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(embedding_identifier):].strip('\n')
embed = load_embed(embedding_name, self.embedding_directory)
if embed is not None:
if len(embed.shape) == 1:
temp_tokens += [(embed, t[1])]
else:
for x in range(embed.shape[0]):
temp_tokens += [(embed[x], t[1])]
elif len(word) > 0:
tt = self.tokenizer(word)["input_ids"][1:-1]
for x in tt:
temp_tokens += [(x, t[1])]
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
#try not to split words in different sections
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
for x in range(tokens_left):
tokens += [(self.end_token, 1.0)]
tokens += temp_tokens
out_tokens = []
for x in range(0, len(tokens), self.max_length - 2):
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
for x in range(0, len(tokens), self.max_tokens_per_section):
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
o_token += [(self.end_token, 1.0)]
if self.pad_with_end:
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))