mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Switch text encoder to manual cast.
Use fp16 text encoder weights for CPU inference to lower memory usage.
This commit is contained in:
33
comfy/ops.py
33
comfy/ops.py
@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype)
|
||||
return weight, bias
|
||||
|
||||
class manual_cast:
|
||||
class Linear(Linear):
|
||||
def forward(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
class Conv2d(Conv2d):
|
||||
def forward(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
class Conv3d(Conv3d):
|
||||
def forward(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
class GroupNorm(GroupNorm):
|
||||
def forward(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
class LayerNorm(LayerNorm):
|
||||
def forward(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
@contextmanager
|
||||
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
|
||||
old_torch_nn_linear = torch.nn.Linear
|
||||
|
Reference in New Issue
Block a user