
Add CublasOps support (#7574)

* CublasOps support

* Guard CublasOps behind --fast arg
catboxanon
2025-04-12 18:29:15 -04:00
committed by GitHub
parent 73ecb75a3d
commit 1714a4c158
2 changed files with 30 additions and 1 deletion
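
Only one of the two changed files is shown in this view. The other change wires the new option into the --fast argument by extending the PerformanceFeature enum referenced in the diff below. A minimal sketch of what that addition plausibly looks like, assuming it lives in comfy/cli_args.py and that the new member's string value is "cublas_ops" (neither detail is visible here):

    import enum

    # Hypothetical reconstruction of the enum change in the second file;
    # the pre-existing members and the "cublas_ops" value are assumptions.
    class PerformanceFeature(enum.Enum):
        Fp16Accumulation = "fp16_accumulation"
        Fp8MatrixMultiplication = "fp8_matrix_mult"
        CublasOps = "cublas_ops"  # new: opted into via --fast

With that in place, the check `PerformanceFeature.CublasOps in args.fast` in the second hunk below is what the commit message means by guarding CublasOps behind the --fast arg.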

comfy/ops.py

@@ -357,6 +357,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
     return scaled_fp8_op
 
+CUBLAS_IS_AVAILABLE = False
+try:
+    from cublas_ops import CublasLinear
+    CUBLAS_IS_AVAILABLE = True
+except ImportError:
+    pass
+
+if CUBLAS_IS_AVAILABLE:
+    class cublas_ops(disable_weight_init):
+        class Linear(CublasLinear, disable_weight_init.Linear):
+            def reset_parameters(self):
+                return None
+
+            def forward_comfy_cast_weights(self, input):
+                return super().forward(input)
+
+            def forward(self, *args, **kwargs):
+                return super().forward(*args, **kwargs)
+
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
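
Two details of the added class are worth noting. Because CublasLinear comes first in the MRO of cublas_ops.Linear, both forward paths route matrix multiplies through the cuBLAS implementation while the class still fits ComfyUI's operations interface. And reset_parameters returning None skips PyTorch's default weight initialization, which would be wasted work since checkpoint weights overwrite the tensors immediately. A minimal standalone sketch of that no-init pattern (the class name here is hypothetical, not from the diff):

    import torch

    # Hypothetical example: overriding reset_parameters with a no-op means
    # torch.nn.Linear.__init__ still allocates the weight and bias tensors
    # but never runs the (relatively costly) default Kaiming init on them.
    class NoInitLinear(torch.nn.Linear):
        def reset_parameters(self):
            return None

    layer = NoInitLinear(1024, 1024)  # allocated, left uninitialized
    layer.load_state_dict({           # a checkpoint load supplies real values
        "weight": torch.zeros(1024, 1024),
        "bias": torch.zeros(1024),
    })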
@@ -369,6 +388,15 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     ):
         return fp8_ops
 
+    if (
+        PerformanceFeature.CublasOps in args.fast and
+        CUBLAS_IS_AVAILABLE and
+        weight_dtype == torch.float16 and
+        (compute_dtype == torch.float16 or compute_dtype is None)
+    ):
+        logging.info("Using cublas ops")
+        return cublas_ops
+
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
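
The net effect is a strictly opt-in fast path: cublas_ops is selected only when the optional cublas_ops package imported successfully, the weights are fp16, the compute dtype is fp16 or unset, and the user passed the matching --fast feature. A standalone restatement of the gate for clarity (the function and parameter names are illustrative, not part of the diff; cublas_requested stands in for `PerformanceFeature.CublasOps in args.fast`):

    import torch

    # Illustrative restatement of the guard added above.
    def should_use_cublas(cublas_requested, cublas_available,
                          weight_dtype, compute_dtype):
        return (
            cublas_requested
            and cublas_available
            and weight_dtype == torch.float16
            and compute_dtype in (torch.float16, None)
        )

    print(should_use_cublas(True, True, torch.float16, None))           # True
    print(should_use_cublas(True, True, torch.float32, torch.float16))  # False: fp16 weights only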