1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Add DualClipLoader to load clip models for SDXL.

Update LoadClip to load clip models for SDXL refiner.
This commit is contained in:
comfyanonymous
2023-06-25 01:40:38 -04:00
parent b7933960bb
commit 20f579d91d
4 changed files with 67 additions and 11 deletions

View File

@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
self.layer = "hidden"
self.layer_idx = layer_idx
def load_sd(self, sd):
if "text_projection" in sd:
self.text_projection[:] = sd.pop("text_projection")
return super().load_sd(sd)
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return self.clip_g.load_sd(sd)
else:
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(torch.nn.Module):
def __init__(self, device="cpu"):
super().__init__()
@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
return g_out, g_pooled
def load_sd(self, sd):
return self.clip_g.load_sd(sd)