mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 07:26:31 +08:00
Support SDXL embedding format with 2 CLIP.
This commit is contained in:
@@ -233,7 +233,7 @@ def expand_directory_list(directories):
|
||||
dirs.add(root)
|
||||
return list(dirs)
|
||||
|
||||
def load_embed(embedding_name, embedding_directory, embedding_size):
|
||||
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
|
||||
@@ -292,13 +292,15 @@ def load_embed(embedding_name, embedding_directory, embedding_size):
|
||||
continue
|
||||
out_list.append(t.reshape(-1, t.shape[-1]))
|
||||
embed_out = torch.cat(out_list, dim=0)
|
||||
elif embed_key is not None and embed_key in embed:
|
||||
embed_out = embed[embed_key]
|
||||
else:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
@@ -315,17 +317,18 @@ class SD1Tokenizer:
|
||||
self.max_word_length = 8
|
||||
self.embedding_identifier = "embedding:"
|
||||
self.embedding_size = embedding_size
|
||||
self.embedding_key = embedding_key
|
||||
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
return (embed, embedding_name[len(stripped):])
|
||||
return (embed, "")
|
||||
|
||||
|
Reference in New Issue
Block a user