1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Support scaled fp8 t5xxl model.

This commit is contained in:
comfyanonymous
2024-10-20 22:27:00 -04:00
parent f9f9faface
commit 83ca891118
6 changed files with 63 additions and 30 deletions

View File

@@ -290,12 +290,21 @@ class fp8_ops(manual_cast):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False):
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
@@ -328,7 +337,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops