Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 15:04:50 +08:00)
Add --fast argument to enable experimental optimizations.
Optimizations that might break things or lower quality will be put behind this flag first, and might be enabled by default in the future. Currently the only optimization is float8_e4m3fn matrix multiplication on 4000/ADA series Nvidia cards or later. If you have one of these cards, you will see a speed boost when using fp8_e4m3fn flux, for example.
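The flag is passed on the command line when launching ComfyUI, e.g.:

    python main.py --fast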
Changed file: comfy/ops.py (41 changes)
@@ -18,7 +18,7 @@

import torch
import comfy.model_management
from comfy.cli_args import args

def cast_to(weight, dtype=None, device=None, non_blocking=False):
    if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
@@ -242,3 +242,42 @@ class manual_cast(disable_weight_init):

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True

def fp8_linear(self, input):
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    if len(input.shape) == 3:
        out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
        inn = input.to(dtype)
        non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
        w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
        for i in range(input.shape[0]):
            if self.bias is not None:
                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
            else:
                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
            out[i] = o
        return out
    return None

class fp8_ops(manual_cast):
    class Linear(manual_cast.Linear):
        def forward_comfy_cast_weights(self, input):
            out = fp8_linear(self, input)
            if out is not None:
                return out

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

def pick_operations(weight_dtype, compute_dtype, load_device=None):
    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init
    if args.fast:
        if comfy.model_management.supports_fp8_compute(load_device):
            return fp8_ops
    return manual_cast
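The selection logic is easy to poke at from a Python shell. A minimal sketch, assuming it is run from a ComfyUI checkout (note that args.fast is only set when ComfyUI is launched with --fast; run standalone, the dispatch falls through to manual_cast):

    import torch
    import comfy.ops
    import comfy.model_management

    # pick_operations() returns fp8_ops only when --fast is active AND the
    # load device supports fp8 compute (4000/ADA series or newer); in every
    # other mismatched-dtype case it keeps returning manual_cast.
    device = comfy.model_management.get_torch_device()
    ops = comfy.ops.pick_operations(torch.float8_e4m3fn, torch.float16, load_device=device)
    print(ops)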
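As for the "lower quality" caveat: float8_e4m3fn keeps only a 3-bit mantissa, so matmul operands are coarsely rounded. A self-contained sketch in plain PyTorch (no ComfyUI, and no torch._scaled_mm; it only round-trips through the dtype to show the rounding error that fp8_linear accepts in exchange for speed):

    import torch

    x = torch.randn(4, 64)
    w = torch.randn(32, 64)
    ref = torch.nn.functional.linear(x, w)  # full-precision reference

    # Quantize both operands to float8_e4m3fn and back, then redo the matmul.
    x8 = x.to(torch.float8_e4m3fn).to(torch.float32)
    w8 = w.to(torch.float8_e4m3fn).to(torch.float32)
    approx = torch.nn.functional.linear(x8, w8)

    print("max abs error:", (ref - approx).abs().max().item())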