
Lower T5 memory usage by a few hundred MB.

commit b85216a3c0
parent 82cae45d44
Author: comfyanonymous
Date:   2024-07-31 00:52:34 -04:00

3 changed files with 33 additions and 17 deletions
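
The saving comes from the Embedding change below: an fp16/bf16 token-embedding table is no longer upcast to the compute dtype on every text-encoder forward pass. A rough back-of-the-envelope, assuming T5-XXL's embedding shape (vocab 32128, d_model 4096 -- an assumption, those numbers are not in this diff):

    vocab_size, d_model = 32128, 4096   # assumed T5-XXL shape, not from this diff
    params = vocab_size * d_model       # ~131.6M embedding weights
    print(params * 2 / 1024**2)         # 251.0 MB: fp16/bf16 table kept resident
    print(params * 4 / 1024**2)         # 502.0 MB: fp32 copy the old code allocated

Skipping that transient fp32 copy is the "few hundred MB" in the commit message.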

comfy/ops.py

@@ -19,17 +19,27 @@
 import torch
 import comfy.model_management
 
-def cast_to_input(weight, input, non_blocking=False):
-    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-def cast_bias_weight(s, input):
+def cast_to_input(weight, input, non_blocking=False):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+
+def cast_bias_weight(s, input=None, dtype=None, device=None):
+    if input is not None:
+        if dtype is None:
+            dtype = input.dtype
+        if device is None:
+            device = input.device
+
     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
+    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
+        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
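
The refactor above splits device/dtype selection out of the cast: cast_bias_weight no longer requires an input tensor and can be asked for an explicit dtype, or for none at all, which leaves the stored dtype untouched. A minimal runnable sketch of the new call contract -- the two helpers are copied from the hunk with the comfy.model_management non_blocking lookup stubbed out, and DummyLayer is an illustrative stand-in, not a class from the repo:

    import torch

    # Helpers copied from the hunk above, minus the comfy.model_management
    # non_blocking lookup (left at the default False) so this runs standalone.
    def cast_to(weight, dtype=None, device=None, non_blocking=False):
        return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)

    def cast_bias_weight(s, input=None, dtype=None, device=None):
        if input is not None:
            if dtype is None:
                dtype = input.dtype
            if device is None:
                device = input.device
        bias = None
        if s.bias is not None:
            bias = cast_to(s.bias, dtype, device)
            if s.bias_function is not None:
                bias = s.bias_function(bias)
        weight = cast_to(s.weight, dtype, device)
        if s.weight_function is not None:
            weight = s.weight_function(weight)
        return weight, bias

    class DummyLayer(torch.nn.Module):
        # Illustrative stand-in: cast_bias_weight only touches .weight, .bias
        # and the two optional function hooks.
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(8, 8, dtype=torch.float16))
            self.bias = torch.nn.Parameter(torch.zeros(8, dtype=torch.float16))
            self.weight_function = None
            self.bias_function = None

    layer = DummyLayer()
    x = torch.randn(4, 8)                                # fp32 activations

    w_old, _ = cast_bias_weight(layer, x)                # old style: follow the input -> fp32 copy
    w_new, _ = cast_bias_weight(layer, device=x.device)  # new style: dtype=None keeps fp16
    print(w_old.dtype, w_new.dtype)                      # torch.float32 torch.float16

The dtype=None path is exactly what the Embedding change in the next hunk relies on.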
@@ -176,14 +186,19 @@ class disable_weight_init:
             self.bias = None
             return None
 
-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            output_dtype = out_dtype
+            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
+                out_dtype = None
+            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
 
         def forward(self, *args, **kwargs):
             if self.comfy_cast_weights:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
+                if "out_dtype" in kwargs:
+                    kwargs.pop("out_dtype")
                 return super().forward(*args, **kwargs)
 
     @classmethod
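
The Embedding override above is where the memory goes: when the table is stored in fp16/bf16, out_dtype is cleared before the cast, so torch.nn.functional.embedding gathers rows in the storage dtype and only that small output is cast to the requested dtype. The kwargs.pop("out_dtype") guard keeps the new keyword from leaking into torch.nn.Embedding.forward when comfy_cast_weights is off. A sketch of the trick in isolation, with an assumed T5-XXL-sized table (not code from the repo):

    import torch
    import torch.nn.functional as F

    table = torch.randn(32128, 4096, dtype=torch.float16)  # assumed T5-XXL shape, ~251 MB
    tokens = torch.randint(0, 32128, (1, 77))

    # Old behaviour: upcast the whole table first -> a ~502 MB transient fp32 copy.
    out_old = F.embedding(tokens, table.to(torch.float32))

    # New behaviour: gather in fp16, then cast only the 1 x 77 x 4096 output (~1.2 MB).
    out_new = F.embedding(tokens, table).to(torch.float32)

    assert out_old.shape == out_new.shape == (1, 77, 4096)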