mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Support properly saving CosXL checkpoints.
This commit is contained in:
@@ -2,7 +2,9 @@ import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_base
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
|
||||
import torch
|
||||
import folder_paths
|
||||
import json
|
||||
import os
|
||||
@@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||
# "v2-inpainting"
|
||||
|
||||
extra_keys = {}
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
|
||||
if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
|
||||
extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
|
||||
extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()
|
||||
|
||||
if model.model.model_type == comfy.model_base.ModelType.EPS:
|
||||
metadata["modelspec.predict_key"] = "epsilon"
|
||||
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
||||
@@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
|
||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||
|
||||
class CheckpointSave:
|
||||
def __init__(self):
|
||||
|
Reference in New Issue
Block a user