mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 23:14:49 +08:00)
Lower T5 memory usage by a few hundred MB.
comfy/ops.py (33 lines changed)
@@ -19,17 +19,27 @@
 import torch
 import comfy.model_management
 
-def cast_to_input(weight, input, non_blocking=False):
-    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-
-def cast_bias_weight(s, input):
+def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
+def cast_to_input(weight, input, non_blocking=False):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+
+def cast_bias_weight(s, input=None, dtype=None, device=None):
+    if input is not None:
+        if dtype is None:
+            dtype = input.dtype
+        if device is None:
+            device = input.device
+
     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
+    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
+        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
         if s.bias_function is not None:
            bias = s.bias_function(bias)
-    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias
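For reference, a minimal standalone sketch of how the reworked helpers resolve dtype and device: an explicit dtype/device argument wins, otherwise both are taken from the input tensor. comfy.model_management and the non_blocking handling are stubbed out here, so this is an illustration of the calling convention rather than the actual ComfyUI code.

import torch

def cast_to(weight, dtype=None, device=None, non_blocking=False):
    # dtype=None / device=None keep the tensor's current dtype/device.
    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)

def cast_bias_weight_sketch(module, input=None, dtype=None, device=None):
    # Mirrors the new cast_bias_weight signature: explicit dtype/device take
    # precedence; otherwise they are inferred from the input tensor.
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if device is None:
            device = input.device
    bias = None if module.bias is None else cast_to(module.bias, dtype, device)
    weight = cast_to(module.weight, dtype, device)
    return weight, bias

lin = torch.nn.Linear(4, 4)                       # fp32 weights
x = torch.randn(2, 4, dtype=torch.float16)
w, b = cast_bias_weight_sketch(lin, x)            # dtype/device taken from x
w32, _ = cast_bias_weight_sketch(lin, dtype=torch.float32, device="cpu")  # explicit
print(w.dtype, w32.dtype)                         # torch.float16 torch.float32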
@@ -176,14 +186,19 @@ class disable_weight_init:
             self.bias = None
             return None
 
-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            output_dtype = out_dtype
+            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
+                out_dtype = None
+            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
 
         def forward(self, *args, **kwargs):
             if self.comfy_cast_weights:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
+                if "out_dtype" in kwargs:
+                    kwargs.pop("out_dtype")
                 return super().forward(*args, **kwargs)
 
     @classmethod
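One plausible reading of where the "few hundred MB" for T5 come from: with the new out_dtype argument, a half-precision embedding table can be looked up as-is and only the (much smaller) output rows are cast to the requested dtype, so a full-precision copy of the whole weight never has to be materialized. The snippet below is a standalone illustration of that pattern only; the shapes are T5-XXL-like assumptions, not measured ComfyUI numbers.

import torch

# Assumed T5-XXL-like embedding shape: ~32k vocab x 4096 features.
vocab, dim = 32128, 4096
table = torch.nn.Embedding(vocab, dim, dtype=torch.float16)
tokens = torch.randint(0, vocab, (1, 77))

# Upcasting the whole table before the lookup creates a large temporary.
full_fp32 = table.weight.to(torch.float32)
out_a = torch.nn.functional.embedding(tokens, full_fp32)

# The out_dtype-style path: look up in fp16, cast only the output.
out_b = torch.nn.functional.embedding(tokens, table.weight).to(torch.float32)

print(full_fp32.numel() * 4 / 1e6, "MB for the full-precision weight copy")
print(out_b.numel() * 4 / 1e6, "MB for the cast output")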