1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

More flexibility with text encoder return values.

Text encoders can now return other values to the CONDITIONING than the cond
and pooled output.
This commit is contained in:
comfyanonymous
2024-07-10 20:06:50 -04:00
parent e44fa5667f
commit 391c1046cf
3 changed files with 30 additions and 8 deletions

View File

@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
extra[k] = v
r = r + (extra,)
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
elif outputs[2] is not None:
pooled_output = outputs[2].float()
extra = {}
if self.return_attention_masks:
return z, pooled_output, attention_mask
extra["attention_mask"] = attention_mask
if len(extra) > 0:
return z, pooled_output, extra
return z, pooled_output
@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out, pooled
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)