1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Always return unprojected pooled output for gligen.

This commit is contained in:
comfyanonymous
2024-02-25 07:20:31 -05:00
parent 1cb3f6a83b
commit c2cb8e889b
6 changed files with 38 additions and 28 deletions

View File

@@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = True
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
self.set_clip_options({"layer": layer_idx})
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
def freeze(self):
self.transformer = self.transformer.eval()
@@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for param in self.parameters():
param.requires_grad = False
def clip_layer(self, layer_idx):
if abs(layer_idx) > self.num_layers:
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def reset_clip_layer(self):
self.layer = self.layer_default[0]
self.layer_idx = self.layer_default[1]
def reset_clip_options(self):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
@@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
z = outputs[1]
if outputs[2] is not None:
pooled_output = outputs[2].float()
else:
pooled_output = None
pooled_output = None
if len(outputs) >= 3:
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
return z.float(), pooled_output
@@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx):
getattr(self, self.clip).clip_layer(layer_idx)
def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options)
def reset_clip_layer(self):
getattr(self, self.clip).reset_clip_layer()
def reset_clip_options(self):
getattr(self, self.clip).reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]