Mirror of https://github.com/comfyanonymous/ComfyUI.git
Cleaner CLIP text encoder implementation.
Use a simple CLIP model implementation instead of the one from transformers. This will allow some interesting things that would be too hackish to implement using the transformers implementation.
@@ -1,12 +1,14 @@
 import os

-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+from transformers import CLIPTokenizer
 import comfy.ops
 import torch
 import traceback
 import zipfile
 from . import model_management
 import contextlib
+import comfy.clip_model
+import json

 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -65,35 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
-                 model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
-        self.num_layers = 12
-        if textmodel_path is not None:
-            self.transformer = model_class.from_pretrained(textmodel_path)
-        else:
-            if textmodel_json_config is None:
-                textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = config_class.from_json_file(textmodel_json_config)
-            self.num_layers = config.num_hidden_layers
-            with comfy.ops.use_comfy_ops(device, dtype):
-                with modeling_utils.no_init_weights():
-                    self.transformer = model_class(config)

-        self.inner_name = inner_name
-        if dtype is not None:
-            inner_model = getattr(self.transformer, self.inner_name)
-            if hasattr(inner_model, "embeddings"):
-                embeddings_bak = inner_model.embeddings.to(torch.float32)
-                inner_model.embeddings = None
-                self.transformer.to(dtype)
-                inner_model.embeddings = embeddings_bak
-            else:
-                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
-                self.transformer.to(dtype)
-                self.transformer.set_input_embeddings(previous_inputs)
+        if textmodel_json_config is None:
+            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+
+        with open(textmodel_json_config) as f:
+            config = json.load(f)
+
+        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.num_layers = self.transformer.num_layers

         self.max_length = max_length
         if freeze:
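For orientation, a minimal sketch of the construction path the rewritten `__init__` uses: the config is now a plain JSON dict and the model class takes dtype, device, and the ops module directly, so the fp16 embedding-backup workaround removed above is no longer needed. This assumes a ComfyUI checkout on `sys.path`; the config path lookup and the fp16/cpu choice here are illustrative, only the `CLIPTextModel(config, dtype, device, comfy.ops)` call shape comes from the diff.

```python
# Sketch only (not part of the commit): the new construction path in isolation,
# using the plain-JSON config instead of CLIPTextConfig + no_init_weights().
import json
import os
import torch
import comfy.ops
import comfy.clip_model

# sd1_clip_config.json ships in the same comfy/ directory as clip_model.py.
config_path = os.path.join(os.path.dirname(os.path.realpath(comfy.clip_model.__file__)), "sd1_clip_config.json")
with open(config_path) as f:
    config = json.load(f)  # plain dict, no transformers config object

# dtype, device and the operations module are passed straight to the model.
model = comfy.clip_model.CLIPTextModel(config, torch.float16, "cpu", comfy.ops)
print(model.num_layers)  # the layer count is now reported by the model itself
```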
@@ -108,7 +94,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
-            assert abs(layer_idx) <= self.num_layers
+            assert abs(layer_idx) < self.num_layers
             self.clip_layer(layer_idx)
         self.layer_default = (self.layer, self.layer_idx)

@@ -119,7 +105,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False

     def clip_layer(self, layer_idx):
-        if abs(layer_idx) >= self.num_layers:
+        if abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
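To illustrate the selection logic after this change, here is a small standalone sketch (a hypothetical helper, not code from the commit). With 12 layers, `layer_idx=-1` selects the final block, `-2` the penultimate one commonly used as "CLIP skip", and out-of-range values fall back to the last hidden state.

```python
# Standalone illustration of the new clip_layer fallback rule.
NUM_LAYERS = 12  # layer count of the SD1 CLIP text encoder

def select_layer(layer_idx, num_layers=NUM_LAYERS):
    # Out-of-range indices fall back to the final layer output ("last").
    if abs(layer_idx) > num_layers:
        return ("last", None)
    return ("hidden", layer_idx)

print(select_layer(-1))   # ('hidden', -1)  -> final transformer block
print(select_layer(-2))   # ('hidden', -2)  -> penultimate block ("CLIP skip 2")
print(select_layer(-13))  # ('last', None)  -> out of range, use last hidden state
```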
@@ -174,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)

-        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
+        if self.transformer.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, dtype: contextlib.nullcontext(a)
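The dtype check above keeps the existing pattern of choosing between `torch.autocast` and a no-op context manager with the same call signature; the only change is that the dtype is now read from the model itself. A small self-contained sketch of that pattern (the function name and device string are illustrative):

```python
import contextlib
import torch

def make_precision_scope(model_dtype):
    # Half-precision weights: run the forward pass under torch.autocast.
    if model_dtype != torch.float32:
        return torch.autocast
    # Full-precision weights: a no-op context manager taking the same arguments.
    return lambda a, dtype: contextlib.nullcontext(a)

assert make_precision_scope(torch.float16) is torch.autocast
scope = make_precision_scope(torch.float32)
with scope("cuda", dtype=torch.float32):
    pass  # the transformer forward pass would run here
```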
@@ -190,20 +176,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                         if tokens[x, y] == max_token:
                             break

-            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
+            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
             self.transformer.set_input_embeddings(backup_embeds)

             if self.layer == "last":
-                z = outputs.last_hidden_state
-            elif self.layer == "pooled":
-                z = outputs.pooler_output[:, None, :]
+                z = outputs[0]
             else:
-                z = outputs.hidden_states[self.layer_idx]
-                if self.layer_norm_hidden_state:
-                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
+                z = outputs[1]

-            if hasattr(outputs, "pooler_output"):
-                pooled_output = outputs.pooler_output.float()
+            if outputs[2] is not None:
+                pooled_output = outputs[2].float()
             else:
                 pooled_output = None

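The new forward call returns plain indexable outputs instead of a transformers `ModelOutput`. Based purely on how they are consumed above, a hypothetical helper that unpacks them; the helper name, dummy tensors, and shapes are illustrative, only the index meanings are taken from the diff.

```python
import torch

def unpack_clip_outputs(outputs, layer="last"):
    """Illustrative helper (not in the commit), interpreting the tuple-style outputs:
    outputs[0] -> final hidden state (used when layer == "last")
    outputs[1] -> requested intermediate hidden state (layer == "hidden")
    outputs[2] -> pooled output, or None when no pooled output is produced
    """
    z = outputs[0] if layer == "last" else outputs[1]
    pooled = outputs[2].float() if outputs[2] is not None else None
    return z, pooled

# Dummy tensors standing in for a real forward pass (SD1 shapes: 77 tokens, 768 dims).
fake = (torch.zeros(1, 77, 768), torch.zeros(1, 77, 768), torch.zeros(1, 768))
z, pooled = unpack_clip_outputs(fake, layer="hidden")
print(z.shape, pooled.shape)  # torch.Size([1, 77, 768]) torch.Size([1, 768])
```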