Mirror of https://github.com/comfyanonymous/ComfyUI.git
Switch text encoder to manual cast.
Use fp16 text encoder weights for CPU inference to lower memory usage.
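Note: "manual cast" here means each layer casts its stored weights to the compute dtype at call time, instead of relying on a torch.autocast wrapper around the whole forward pass; that is what lets the text encoder keep fp16 weights while still computing in fp32 on CPU. A minimal sketch of the idea (illustrative only, not the actual comfy.ops.manual_cast implementation):

import torch

class ManualCastLinear(torch.nn.Linear):
    # Weights stay in their storage dtype (e.g. fp16) and are cast to the
    # input's dtype only for the duration of this call, so CPU inference can
    # run the matmul in fp32 while the parameters keep the smaller fp16
    # memory footprint.
    def forward(self, input):
        weight = self.weight.to(dtype=input.dtype, device=input.device)
        bias = None if self.bias is None else self.bias.to(dtype=input.dtype, device=input.device)
        return torch.nn.functional.linear(input, weight, bias)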
@@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)
 
-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers
 
         self.max_length = max_length
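The one-line change above works because the text model builds its layers from whichever operations container it is handed, so swapping comfy.ops for comfy.ops.manual_cast replaces the layer classes (e.g. with a casting Linear like the sketch above) without touching the model definition. A rough, self-contained sketch of that injection pattern (TinyTextModel and default_ops are hypothetical, not the real CLIP code):

import torch

class TinyTextModel(torch.nn.Module):
    # The model asks the operations container for its layer classes, so a
    # single constructor argument decides whether plain or manual-cast
    # layers are used.
    def __init__(self, hidden_size, operations):
        super().__init__()
        self.proj = operations.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)

# Plain ops vs. casting ops are just different containers of layer classes.
default_ops = type("ops", (), {"Linear": torch.nn.Linear})
model = TinyTextModel(16, operations=default_ops)
print(model(torch.randn(2, 16)).shape)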
@@ -160,37 +160,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
-                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-                for x in range(attention_mask.shape[0]):
-                    for y in range(attention_mask.shape[1]):
-                        attention_mask[x, y] = 1
-                        if tokens[x, y] == max_token:
-                            break
-
-            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
-            self.transformer.set_input_embeddings(backup_embeds)
-
-            if self.layer == "last":
-                z = outputs[0]
-            else:
-                z = outputs[1]
-
-            if outputs[2] is not None:
-                pooled_output = outputs[2].float()
-            else:
-                pooled_output = None
-
-            if self.text_projection is not None and pooled_output is not None:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
+            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            for x in range(attention_mask.shape[0]):
+                for y in range(attention_mask.shape[1]):
+                    attention_mask[x, y] = 1
+                    if tokens[x, y] == max_token:
+                        break
+
+        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        self.transformer.set_input_embeddings(backup_embeds)
+
+        if self.layer == "last":
+            z = outputs[0]
+        else:
+            z = outputs[1]
+
+        if outputs[2] is not None:
+            pooled_output = outputs[2].float()
+        else:
+            pooled_output = None
+
+        if self.text_projection is not None and pooled_output is not None:
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output
 
     def encode(self, tokens):
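The removed block wrapped the forward pass in torch.autocast (or a null context for fp32 weights); with manual-cast layers doing their own dtype conversion per call, that wrapper is unnecessary, so the attention-mask construction and the transformer call are simply dedented out of it, and z and pooled_output are still promoted with .float() before being returned. A standalone restatement of the retained mask-building loop (hypothetical helper name, with an explicit end-token id instead of the embedding-table max index):

import torch

def build_clip_attention_mask(tokens, end_token_id):
    # Mark every position as attended (1) up to and including the first
    # end token in each row; everything after it stays 0 (masked out).
    mask = torch.zeros_like(tokens)
    for row in range(tokens.shape[0]):
        for col in range(tokens.shape[1]):
            mask[row, col] = 1
            if tokens[row, col] == end_token_id:
                break
    return mask

print(build_clip_attention_mask(torch.tensor([[101, 7, 102, 0, 0]]), end_token_id=102))
# -> tensor([[1, 1, 1, 0, 0]])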