LoRA Trainer: LoRA training node in weight adapter scheme (#8446)
comfy_extras/nodes_train.py (new file, 705 lines)
@@ -0,0 +1,705 @@
import datetime
import json
import logging
import os

import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
import tqdm

import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters


class TrainSampler(comfy.samplers.Sampler):
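    # A Sampler subclass that performs one training step per "sampling" call: it compares the
    # model's x0 prediction against the clean latent, backpropagates the loss and steps the optimizer.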

    def __init__(self, loss_fn, optimizer, loss_callback=None):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        self.optimizer.zero_grad()
        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
        latent = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(sigmas),
            torch.zeros_like(noise, requires_grad=True),
            latent_image,
            False
        )
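        # Evaluating noise_scaling with zero sigmas (and zero extra noise) recovers the clean
        # latent in the model's own scaling; it serves as the target for the x0 prediction below.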

        # Ensure model is in training mode and computing gradients
        # x0 pred
        denoised = model_wrap(noise, sigmas, **extra_args)
        try:
            loss = self.loss_fn(denoised, latent.clone())
        except RuntimeError as e:
            if "does not require grad and does not have a grad_fn" in str(e):
                logging.info("WARNING: This is likely due to the model being loaded in inference mode.")
        loss.backward()
        if self.loss_callback:
            self.loss_callback(loss.item())

        self.optimizer.step()
        # torch.cuda.memory._dump_snapshot("trainn.pickle")
        # torch.cuda.memory._record_memory_history(enabled=None)
        return torch.zeros_like(latent_image)


class BiasDiff(torch.nn.Module):
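    # Weight wrapper that adds a trainable offset to a tensor; used for bias terms and for
    # 1-D weights that the low-rank adapters do not cover.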
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()


def load_and_process_images(image_files, input_dir, resize_method="None"):
    """Utility function to load and process a list of images.

    Args:
        image_files: List of image filenames
        input_dir: Base directory containing the images
        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")

    Returns:
        torch.Tensor: Batch of processed images
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []
    w, h = None, None

    for file in image_files:
        image_path = os.path.join(input_dir, file)
        img = node_helpers.pillow(Image.open, image_path)

        if img.mode == "I":
            img = img.point(lambda i: i * (1 / 255))
        img = img.convert("RGB")

        if w is None and h is None:
            w, h = img.size[0], img.size[1]

        # Resize image to first image
        if img.size[0] != w or img.size[1] != h:
            if resize_method == "Stretch":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "Crop":
                img = img.crop((0, 0, w, h))
            elif resize_method == "Pad":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "None":
                raise ValueError(
                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                )

        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]
        output_images.append(img_tensor)

    return torch.cat(output_images, dim=0)


class LoadImageSetNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": (
                    [
                        f
                        for f in os.listdir(folder_paths.get_input_directory())
                        if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
                    ],
                    {"image_upload": True, "allow_batch": True},
                )
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
            },
        }

    INPUT_IS_LIST = True
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images from a directory for training."

    @classmethod
    def VALIDATE_INPUTS(s, images, resize_method):
        filenames = images[0] if isinstance(images[0], list) else images

        for image in filenames:
            if not folder_paths.exists_annotated_filepath(image):
                return "Invalid image file: {}".format(image)
        return True

    def load_images(self, images, resize_method):
        # The declared input is named "images"; with INPUT_IS_LIST set, values arrive wrapped in lists.
        input_files = images[0] if isinstance(images[0], list) else images
        if isinstance(resize_method, list):
            resize_method = resize_method[0]

        input_dir = folder_paths.get_input_directory()
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
        image_files = [
            f
            for f in input_files
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        output_tensor = load_and_process_images(image_files, input_dir, resize_method)
        return (output_tensor,)


class LoadImageSetFromFolderNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images from a directory for training."

    def load_images(self, folder, resize_method):
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
        image_files = [
            f
            for f in os.listdir(sub_input_dir)
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
        return (output_tensor,)


def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
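    # Assumes the loss values are not all identical; max_loss == min_loss would divide by zero here.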

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l in enumerate(scaled_loss[1:], start=1):
        x = int(i / (steps - 1) * width)
        y = height - int(l * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img


def find_all_highest_child_module_with_forward(model: torch.nn.Module, result=None, name=None):
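    # Recursively collect the top-most submodules that define forward() (skipping pure containers)
    # so gradient checkpointing can be applied at that granularity.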
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


def patch(m):
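    # Wrap the module's forward in torch.utils.checkpoint so activations are recomputed during
    # backward instead of being stored, trading compute for memory while training.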
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward
    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)
    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )
    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward


class TrainLoraNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
                "latents": (
                    "LATENT",
                    {
                        "tooltip": "The latents to use for training; they serve as the dataset/input for the model."
                    },
                ),
                "positive": (
                    IO.CONDITIONING,
                    {"tooltip": "The positive conditioning to use for training."},
                ),
                "batch_size": (
                    IO.INT,
                    {
                        "default": 1,
                        "min": 1,
                        "max": 10000,
                        "step": 1,
                        "tooltip": "The batch size to use for training.",
                    },
                ),
                "steps": (
                    IO.INT,
                    {
                        "default": 16,
                        "min": 1,
                        "max": 100000,
                        "tooltip": "The number of steps to train the LoRA for.",
                    },
                ),
                "learning_rate": (
                    IO.FLOAT,
                    {
                        "default": 0.0005,
                        "min": 0.0000001,
                        "max": 1.0,
                        "step": 0.000001,
                        "tooltip": "The learning rate to use for training.",
                    },
                ),
                "rank": (
                    IO.INT,
                    {
                        "default": 8,
                        "min": 1,
                        "max": 128,
                        "tooltip": "The rank of the LoRA layers.",
                    },
                ),
                "optimizer": (
                    ["AdamW", "Adam", "SGD", "RMSprop"],
                    {
                        "default": "AdamW",
                        "tooltip": "The optimizer to use for training.",
                    },
                ),
                "loss_function": (
                    ["MSE", "L1", "Huber", "SmoothL1"],
                    {
                        "default": "MSE",
                        "tooltip": "The loss function to use for training.",
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
                    },
                ),
                "training_dtype": (
                    ["bf16", "fp32"],
                    {"default": "bf16", "tooltip": "The dtype to use for training."},
                ),
                "lora_dtype": (
                    ["bf16", "fp32"],
                    {"default": "bf16", "tooltip": "The dtype to use for lora."},
                ),
                "existing_lora": (
                    folder_paths.get_filename_list("loras") + ["[None]"],
                    {
                        "default": "[None]",
                        "tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
                    },
                ),
            },
        }

    RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
    RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
    FUNCTION = "train"
    CATEGORY = "training"
    EXPERIMENTAL = True

    def train(
        self,
        model,
        latents,
        positive,
        batch_size,
        steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        existing_lora,
    ):
        mp = model.clone()
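        # Work on a clone of the ModelPatcher so the input model stays untouched; the trainable
        # adapters are attached to the clone as weight wrappers.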
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        latents = latents["samples"].to(dtype)
        num_images = latents.shape[0]

        with torch.inference_mode(False):
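            # Inference mode must be disabled so autograd can track the trainable adapter
            # parameters created below.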
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(
                                f"{key}.dora_scale", None
                            )
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                # If no existing adapter found, use LoRA
                                # We will add algo option in the future
                                existing_adapter = None
                                adapter_cls = adapters[0]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(lora_dtype)
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)
                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, BiasDiff(diff))
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff
                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, BiasDiff(bias))
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Setup loss function based on selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # setup models
            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                patch(m)
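            # The huge memory_required together with force_full_load keeps the whole patched
            # model resident on the compute device for the duration of training.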
            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

            # Setup sampler and guider like in test script
            loss_map = {"loss": []}
            def loss_callback(loss):
                loss_map["loss"].append(loss)
                pbar.set_postfix({"loss": f"{loss:.4f}"})
            train_sampler = TrainSampler(
                criterion, optimizer, loss_callback=loss_callback
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input
            ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()

            # yoland: this currently resizes to the first image in the dataset

            # Training loop
            torch.cuda.empty_cache()
            try:
                for step in (pbar := tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
                    # Generate random sigma
                    sigma = mp.model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    )
                    sigma = torch.tensor([sigma])

                    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)

                    indices = torch.randperm(num_images)[:batch_size]
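                    # Train on a random mini-batch of latents; TrainSampler performs one
                    # forward/backward pass and one optimizer update inside ss.sample().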
                    ss.sample(
                        noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
                    )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
                del ss, train_sampler, optimizer
                torch.cuda.empty_cache()

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

            return (mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
    FUNCTION = "load_lora_model"

    CATEGORY = "loaders"
    DESCRIPTION = "Load trained LoRA weights from the Train LoRA node."
    EXPERIMENTAL = True

    def load_lora_model(self, model, lora, strength_model):
        if strength_model == 0:
            return (model, )

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return (model_lora, )


class SaveLoRA:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (
                    IO.LORA_MODEL,
                    {
                        "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
                    },
                ),
                "prefix": (
                    "STRING",
                    {
                        "default": "trained_lora",
                        "tooltip": "The prefix to use for the saved LoRA file.",
                    },
                ),
            },
            "optional": {
                "steps": (
                    IO.INT,
                    {
                        "forceInput": True,
                        "tooltip": "Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
                    },
                ),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    OUTPUT_NODE = True

    def save(self, lora, prefix, steps=None):
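        # Written directly to models/loras; the step count is encoded in the filename so
        # TrainLoraNode can parse it back when extending an existing LoRA.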
        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        if steps is None:
            output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
        else:
            output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
        safetensors.torch.save_file(lora, output_file)
        return {}


class LossGraphNode:
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "loss": (IO.LOSS_MAP, {"default": {}}),
                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "plot_loss"
    OUTPUT_NODE = True
    CATEGORY = "training"
    EXPERIMENTAL = True
    DESCRIPTION = "Plots the loss graph and saves it to the output directory."

    def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        font = None
        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        metadata = None
        if not args.disable_metadata:
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))

        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        img.save(
            os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
            pnginfo=metadata,
        )
        return {
            "ui": {
                "images": [
                    {
                        "filename": f"{filename_prefix}_{date}.png",
                        "subfolder": "",
                        "type": "temp",
                    }
                ]
            }
        }


NODE_CLASS_MAPPINGS = {
    "TrainLoraNode": TrainLoraNode,
    "SaveLoRANode": SaveLoRA,
    "LoraModelLoader": LoraModelLoader,
    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
    "LossGraphNode": LossGraphNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TrainLoraNode": "Train LoRA",
    "SaveLoRANode": "Save LoRA Weights",
    "LoraModelLoader": "Load LoRA Model",
    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
    "LossGraphNode": "Plot Loss Graph",
}