1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support diffusion models with scaled fp8 weights.

This commit is contained in:
comfyanonymous
2024-10-19 23:47:42 -04:00
parent 73e3a9e676
commit a68bbafddb
7 changed files with 95 additions and 23 deletions

View File

@@ -19,6 +19,7 @@
import torch
import comfy.model_management
from comfy.cli_args import args
import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -250,18 +251,18 @@ def fp8_linear(self, input):
return None
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
else:
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@@ -289,15 +290,46 @@ class fp8_ops(manual_cast):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False):
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
if comfy.model_management.supports_fp8_compute(load_device):
if (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
def convert_weight(self, weight):
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast and not disable_fast_fp8:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast