mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Add CheckpointSave node to save checkpoints.
The created checkpoints contain workflow metadata that can be loaded by dragging them on top of the UI or loading them with the "Load" button. Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI is using for inference on your hardware. To force fp32 use: --force-fp32 Anything that patches the model weights like merging or loras will be saved. The output directory is currently set to: output/checkpoints but that might change in the future.
This commit is contained in:
32
comfy/sd.py
32
comfy/sd.py
@@ -545,11 +545,11 @@ class CLIP:
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||
try:
|
||||
self.patcher.patch_model()
|
||||
self.patch_model()
|
||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||
self.patcher.unpatch_model()
|
||||
self.unpatch_model()
|
||||
except Exception as e:
|
||||
self.patcher.unpatch_model()
|
||||
self.unpatch_model()
|
||||
raise e
|
||||
|
||||
cond_out = cond
|
||||
@@ -564,6 +564,15 @@ class CLIP:
|
||||
def load_sd(self, sd):
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
|
||||
def patch_model(self):
|
||||
self.patcher.patch_model()
|
||||
|
||||
def unpatch_model(self):
|
||||
self.patcher.unpatch_model()
|
||||
|
||||
class VAE:
|
||||
def __init__(self, ckpt_path=None, device=None, config=None):
|
||||
if config is None:
|
||||
@@ -665,6 +674,10 @@ class VAE:
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
@@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
print("left over keys:", left_over)
|
||||
|
||||
return (ModelPatcher(model), clip, vae, clipvision)
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
try:
|
||||
model.patch_model()
|
||||
clip.patch_model()
|
||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||
utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||
model.unpatch_model()
|
||||
clip.unpatch_model()
|
||||
except Exception as e:
|
||||
model.unpatch_model()
|
||||
clip.unpatch_model()
|
||||
raise e
|
||||
|
Reference in New Issue
Block a user