Make applying embeddings more efficient.
Adding new tokens no longer makes a whole copy of the embedding weight matrix, which can be massive on certain models.
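To make the claim concrete: a caller can now look up the base tokens itself, splice any custom embedding vectors (e.g. textual-inversion tokens) into the sequence, and pass the result to forward() as embeds, instead of writing the new vectors into an enlarged copy of the token embedding matrix. The sketch below is a hypothetical caller, not code from this commit; build_embeds, its arguments, and the "<my-token>" name are illustrative assumptions.

import torch

def build_embeds(token_embedding, token_ids, custom_vectors):
    # token_embedding: nn.Embedding holding the base vocabulary (left untouched)
    # token_ids: ints index the base vocab; strings name custom vectors
    # custom_vectors: dict of name -> tensor of shape [embed_dim]
    rows = []
    for t in token_ids:
        if isinstance(t, str):
            rows.append(custom_vectors[t])          # new token: no weight copy needed
        else:
            rows.append(token_embedding.weight[t])  # plain lookup into the existing weight
    return torch.stack(rows).unsqueeze(0)           # [1, seq_len, embed_dim]

emb = torch.nn.Embedding(49408, 768)
vecs = {"<my-token>": torch.randn(768)}
embeds = build_embeds(emb, [49406, 320, "<my-token>", 49407], vecs)
# then: model(embeds=embeds, num_tokens=[4]) instead of resizing emb.weight

Note that embeds carries token embeddings only; the position embeddings are added inside forward(), as the first hunk below shows.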
@@ -97,8 +97,12 @@ class CLIPTextModel_(torch.nn.Module):
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
 
-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
-        x = self.embeddings(input_tokens, dtype=dtype)
+    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+        if embeds is not None:
+            x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
+        else:
+            x = self.embeddings(input_tokens, dtype=dtype)
+
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
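Shape-wise, the new embeds branch just adds the position table, which broadcasts over the batch dimension. A minimal sketch with a stand-in for comfy.ops.cast_to (assumed here to be a plain dtype/device cast; the real helper may do more):

import torch

def cast_to(weight, dtype=None, device=None):
    # stand-in for comfy.ops.cast_to as called above (assumption)
    return weight.to(dtype=dtype, device=device)

batch, seq_len, embed_dim = 2, 77, 768
embeds = torch.randn(batch, seq_len, embed_dim)   # precomputed token embeddings
pos = torch.randn(seq_len, embed_dim)             # position_embedding.weight
x = embeds + cast_to(pos, dtype=torch.float32, device=embeds.device)
assert x.shape == (batch, seq_len, embed_dim)     # [seq_len, embed_dim] broadcasts over batch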
@@ -116,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
         if i is not None and final_layer_norm_intermediate:
             i = self.final_layer_norm(i)
 
-        pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+        if num_tokens is not None:
+            pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
+        else:
+            pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
         return x, i, pooled_output
 
 class CLIPTextModel(torch.nn.Module):
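The pooling change mirrors the new signature: when num_tokens is given, the hidden state at index num_tokens[i] - 1 (the EOS position of sequence i) is taken directly, so input_tokens is no longer needed to locate EOS. A toy equivalence check, assuming CLIP's usual EOS id 49407 and EOS right-padding (the torch.round() in the old path only matters for float token inputs and is dropped here):

import torch

eos_token_id = 49407
x = torch.randn(2, 77, 768)                    # final hidden states
input_tokens = torch.full((2, 77), 49407)      # right-padded with EOS
input_tokens[0, :5] = torch.tensor([49406, 320, 1125, 539, 49407])  # 5 real tokens
input_tokens[1, :3] = torch.tensor([49406, 320, 49407])             # 3 real tokens
num_tokens = [5, 3]

# new path: index the last real token of each sequence directly
pooled_new = x[list(range(x.shape[0])), [n - 1 for n in num_tokens]]

# old path: find the first EOS token in each row
pooled_old = x[torch.arange(x.shape[0]), (input_tokens == eos_token_id).int().argmax(dim=-1)]

assert torch.equal(pooled_new, pooled_old)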