Mirror of https://github.com/comfyanonymous/ComfyUI.git
Always return unprojected pooled output for gligen.
@@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.enable_attention_masks = enable_attention_masks
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
+        self.return_projected_pooled = True
+
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) < self.num_layers
-            self.clip_layer(layer_idx)
-        self.layer_default = (self.layer, self.layer_idx)
+            self.set_clip_options({"layer": layer_idx})
+        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
 
     def freeze(self):
         self.transformer = self.transformer.eval()
@@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         for param in self.parameters():
             param.requires_grad = False
 
-    def clip_layer(self, layer_idx):
-        if abs(layer_idx) > self.num_layers:
+    def set_clip_options(self, options):
+        layer_idx = options.get("layer", self.layer_idx)
+        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+        if layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
             self.layer_idx = layer_idx
 
-    def reset_clip_layer(self):
-        self.layer = self.layer_default[0]
-        self.layer_idx = self.layer_default[1]
+    def reset_clip_options(self):
+        self.layer = self.options_default[0]
+        self.layer_idx = self.options_default[1]
+        self.return_projected_pooled = self.options_default[2]
 
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
@@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         else:
             z = outputs[1]
 
-        if outputs[2] is not None:
-            pooled_output = outputs[2].float()
-        else:
-            pooled_output = None
+        pooled_output = None
+        if len(outputs) >= 3:
+            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
+                pooled_output = outputs[3].float()
+            elif outputs[2] is not None:
+                pooled_output = outputs[2].float()
 
         return z.float(), pooled_output
@@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
         self.clip = "clip_{}".format(self.clip_name)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
 
-    def clip_layer(self, layer_idx):
-        getattr(self, self.clip).clip_layer(layer_idx)
+    def set_clip_options(self, options):
+        getattr(self, self.clip).set_clip_options(options)
 
-    def reset_clip_layer(self):
-        getattr(self, self.clip).reset_clip_layer()
+    def reset_clip_options(self):
+        getattr(self, self.clip).reset_clip_options()
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
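For reference, the pooled-output selection introduced above can be exercised in isolation. The sketch below is not part of the commit: it copies the branching from the updated forward() into a small stand-in class so the effect of set_clip_options({"projected_pooled": False}) is visible on its own. The class name, tensor shapes, and values are illustrative assumptions, not ComfyUI code.

    # Standalone sketch mirroring the pooled-output logic from the diff above.
    import torch

    class PooledOutputDemo:
        def __init__(self):
            # Default matches the diff: return the projected pooled output.
            self.return_projected_pooled = True

        def set_clip_options(self, options):
            # Only the option relevant to this sketch is handled here.
            self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)

        def pick_pooled(self, outputs):
            # Same branching as the updated forward(): prefer the unprojected
            # pooled output (outputs[3]) when projection is disabled.
            pooled_output = None
            if len(outputs) >= 3:
                if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                    pooled_output = outputs[3].float()
                elif outputs[2] is not None:
                    pooled_output = outputs[2].float()
            return pooled_output

    hidden = torch.zeros(1, 77, 768)
    projected = torch.full((1, 768), 1.0)    # stand-in for outputs[2]
    unprojected = torch.full((1, 768), 2.0)  # stand-in for outputs[3]
    outputs = (hidden, hidden, projected, unprojected)

    demo = PooledOutputDemo()
    print(demo.pick_pooled(outputs).mean().item())  # 1.0 -> projected pooled (default)
    demo.set_clip_options({"projected_pooled": False})
    print(demo.pick_pooled(outputs).mean().item())  # 2.0 -> unprojected pooled

In the real model the same toggle is what a caller such as the gligen conditioning path would use through set_clip_options, with reset_clip_options restoring the defaults recorded at construction time.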