1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Automatically use fp8 for diffusion model weights if:

Checkpoint contains weights in fp8.

There isn't enough memory to load the diffusion model in GPU vram.
This commit is contained in:
comfyanonymous
2024-08-03 13:45:19 -04:00
parent f123328b82
commit ba9095e5bd
4 changed files with 34 additions and 4 deletions

View File

@@ -510,13 +510,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)