1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

SD1 and SD2 clip and tokenizer code is now more similar to the SDXL one.

This commit is contained in:
comfyanonymous
2023-10-27 15:54:04 -04:00
parent 6ec3f12c6e
commit e60ca6929a
5 changed files with 69 additions and 30 deletions

View File

@@ -35,7 +35,7 @@ class ClipTokenWeightEncoder:
return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
@@ -342,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = next(iter(values))
return embed_out
class SD1Tokenizer:
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
@@ -454,3 +454,40 @@ class SD1Tokenizer:
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
class SD1Tokenizer:
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel):
super().__init__()
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype))
def clip_layer(self, layer_idx):
getattr(self, self.clip).clip_layer(layer_idx)
def reset_clip_layer(self):
getattr(self, self.clip).reset_clip_layer()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out, pooled
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)