1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Make applying embeddings more efficient.

Adding new tokens no longer makes a whole copy of the embeddings weight
which can be massive on certain models.
This commit is contained in:
comfyanonymous
2025-03-05 17:34:38 -05:00
parent 5d84607bf3
commit 85ef295069
5 changed files with 81 additions and 61 deletions

View File

@@ -97,8 +97,12 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
x = self.embeddings(input_tokens, dtype=dtype)
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
x = self.embeddings(input_tokens, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -116,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
if num_tokens is not None:
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
else:
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):