1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 07:26:31 +08:00

Make the stochastic fp8 rounding reproducible.

This commit is contained in:
comfyanonymous
2024-08-26 15:12:06 -04:00
parent 7985ff88b9
commit 2ca8f6e23d
2 changed files with 21 additions and 7 deletions

View File

@@ -1,18 +1,18 @@
import torch
import math
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS):
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand_like(mantissa_scaled)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype):
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
@@ -34,7 +34,7 @@ def manual_stochastic_round_to_float8(x, dtype):
# Combine mantissa calculation and rounding
normal_mask = ~(exponent == 0)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS)
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
@@ -47,7 +47,7 @@ def manual_stochastic_round_to_float8(x, dtype):
def stochastic_rounding(value, dtype):
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
@@ -55,6 +55,8 @@ def stochastic_rounding(value, dtype):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
return manual_stochastic_round_to_float8(value, dtype)
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
return value.to(dtype=dtype)