
Support fp8_scaled diffusion models that don't use fp8 matrix mult.

comfyanonymous
2025-03-07 04:37:58 -05:00
parent e62d72e8ca
commit e1474150de
3 changed files with 8 additions and 2 deletions
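
The commit title refers to "fp8_scaled" checkpoints: weights stored in a float8 format together with a dequantization scale. The core idea of running such a model without fp8 matrix mult is sketched below (a minimal illustration, not ComfyUI's actual code; dequantized_linear is a made-up name): the float8 weight is dequantized to the activation dtype and fed through an ordinary matmul.

import torch

# Minimal sketch of scaled-fp8 inference without an fp8 GEMM:
# dequantize the float8 weight with its scale, then use a normal linear.
def dequantized_linear(x, weight_fp8, scale, bias=None):
    weight = weight_fp8.to(x.dtype) * scale.to(x.dtype)
    return torch.nn.functional.linear(x, weight, bias)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = (torch.randn(4, 8) * 0.1).to(torch.float8_e4m3fn)  # needs a PyTorch build with float8 support
print(dequantized_linear(x, w, torch.tensor(0.5)).shape)  # torch.Size([2, 4])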

comfy/ops.py

@@ -17,6 +17,7 @@
 """
 import torch
+import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
             return torch.nn.functional.linear(input, weight, bias)
 
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
     class scaled_fp8_op(manual_cast):
         class Linear(manual_cast.Linear):
             def __init__(self, *args, **kwargs):
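
The logging line added above reports the two knobs of scaled_fp8_ops. A rough, self-contained sketch of what the returned Linear might do with fp8_matrix_mult (an assumed simplification, not the ComfyUI class; the real one also handles input scaling, weight casting, and dtype overrides):

import torch
import torch.nn.functional as F

class ToyScaledFp8Linear(torch.nn.Module):  # hypothetical stand-in
    def __init__(self, weight_fp8, scale_weight, bias=None, fp8_matrix_mult=False):
        super().__init__()
        self.weight = weight_fp8          # torch.float8_e4m3fn weight
        self.scale_weight = scale_weight  # per-tensor dequantization scale
        self.bias = bias
        self.fp8_matrix_mult = fp8_matrix_mult

    def forward(self, x):
        if self.fp8_matrix_mult:
            # Fast path: a native fp8 GEMM (e.g. torch._scaled_mm on GPUs
            # that support fp8 compute); omitted in this sketch.
            raise NotImplementedError
        # Fallback path this commit enables: dequantize, then plain matmul.
        w = self.weight.to(x.dtype) * self.scale_weight.to(x.dtype)
        return F.linear(x, w, self.bias)

layer = ToyScaledFp8Linear((torch.randn(4, 8) * 0.1).to(torch.float8_e4m3fn), torch.tensor(0.5))
print(layer(torch.randn(2, 8, dtype=torch.bfloat16)).shape)  # torch.Size([2, 4])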
@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8)
     if (
         fp8_compute and
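
The functional change is in this last hunk: previously a scaled-fp8 model used the fp8 matrix mult path whenever the device supported fp8 compute; now it also requires fp8_optimizations to be requested. Restated as a standalone predicate (same logic as the diff, extracted here for clarity):

# Whether a scaled-fp8 model takes the fp8 GEMM fast path:
# old: fp8_compute
# new: fp8_compute and fp8_optimizations
def use_fp8_matrix_mult(fp8_compute: bool, fp8_optimizations: bool) -> bool:
    return fp8_compute and fp8_optimizations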