
Switch text encoder to manual cast.

Use fp16 text encoder weights for CPU inference to lower memory usage.
comfyanonymous
2023-12-10 23:00:54 -05:00
parent 69033081c5
commit 57926635e8
3 changed files with 59 additions and 29 deletions
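
For context: a manual-cast op stores its weights in a compact dtype (fp16 here) and upcasts them to the input's dtype inside forward(), instead of relying on a torch.autocast context around the whole encoder. The sketch below illustrates that idea only; ManualCastLinear is a hypothetical name, not ComfyUI's actual comfy.ops.manual_cast implementation.

import torch

class ManualCastLinear(torch.nn.Linear):
    # Hypothetical manual-cast layer: weights stay stored in fp16
    # (halving memory), and are cast per call to the compute dtype.
    def forward(self, input):
        # Cast parameters on the fly; the stored copies remain fp16,
        # so only the active layer is ever held in fp32.
        weight = self.weight.to(dtype=input.dtype, device=input.device)
        bias = self.bias.to(dtype=input.dtype, device=input.device) if self.bias is not None else None
        return torch.nn.functional.linear(input, weight, bias)

# fp16 storage, fp32 compute on CPU:
layer = ManualCastLinear(768, 768).half()
x = torch.randn(2, 768)            # fp32 input
print(layer(x).dtype)              # torch.float32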


@@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)
 
-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers
 
         self.max_length = max_length
@@ -160,37 +160,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
-                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-                for x in range(attention_mask.shape[0]):
-                    for y in range(attention_mask.shape[1]):
-                        attention_mask[x, y] = 1
-                        if tokens[x, y] == max_token:
-                            break
-
-            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
-            self.transformer.set_input_embeddings(backup_embeds)
-
-            if self.layer == "last":
-                z = outputs[0]
-            else:
-                z = outputs[1]
-
-            if outputs[2] is not None:
-                pooled_output = outputs[2].float()
-            else:
-                pooled_output = None
-
-            if self.text_projection is not None and pooled_output is not None:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
+            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            for x in range(attention_mask.shape[0]):
+                for y in range(attention_mask.shape[1]):
+                    attention_mask[x, y] = 1
+                    if tokens[x, y] == max_token:
+                        break
+
+        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        self.transformer.set_input_embeddings(backup_embeds)
+
+        if self.layer == "last":
+            z = outputs[0]
+        else:
+            z = outputs[1]
+
+        if outputs[2] is not None:
+            pooled_output = outputs[2].float()
+        else:
+            pooled_output = None
+
+        if self.text_projection is not None and pooled_output is not None:
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output
 
     def encode(self, tokens):
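
A side note on the mask-building loop this change keeps (now outside any autocast scope): it marks every position up to and including the first end-of-text token, i.e. the last index of the embedding table, as attended. Below is a standalone sketch of that logic; the token ids 1-5 and the end-of-text id 7 are made up for illustration.

import torch

# Dummy batch padded with an end-of-text id; in the real code,
# max_token is the last row of the embedding table (weight.shape[0] - 1).
max_token = 7
tokens = torch.tensor([[1, 2, 7, 7, 7],
                       [3, 4, 5, 7, 7]])

# Same loop shape as the diff: attend (1) up to and including the
# first max_token, leave everything after it masked (0).
attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1
        if tokens[x, y] == max_token:
            break

print(attention_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0]])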