1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Add --fast argument to enable experimental optimizations.

Optimizations that might break things/lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
4000/ADA series Nvidia cards or later. If you have one of these cards you
will see a speed boost when using fp8_e4m3fn flux for example.
This commit is contained in:
comfyanonymous
2024-08-20 11:49:33 -04:00
parent d1a6bd6845
commit 9953f22fce
4 changed files with 52 additions and 5 deletions

View File

@@ -18,7 +18,7 @@
import torch
import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False):
if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
@@ -242,3 +242,42 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
if len(input.shape) == 3:
out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
inn = input.to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
for i in range(input.shape[0]):
if self.bias is not None:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
else:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
out[i] = o
return out
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast