1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 07:26:31 +08:00

Implement beta sampling scheduler.

It is based on: https://arxiv.org/abs/2407.12173

Add "beta" to the list of schedulers and the BetaSamplingScheduler node.
This commit is contained in:
comfyanonymous
2024-07-19 17:44:56 -04:00
parent 011b11d8d7
commit 6ab8cad22e
2 changed files with 37 additions and 1 deletions

View File

@@ -6,6 +6,8 @@ from comfy import model_management
import math
import logging
import comfy.sampler_helpers
import scipy
import numpy
def get_area_and_mult(conds, x_in, timestep_in):
dims = tuple(x_in.shape[2:])
@@ -337,6 +339,18 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
sigs += [0.0]
return torch.FloatTensor(sigs)
# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
total_timesteps = (len(model_sampling.sigmas) - 1)
ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
sigs = []
for t in ts:
sigs += [float(model_sampling.sigmas[int(t)])]
sigs += [0.0]
return torch.FloatTensor(sigs)
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -703,7 +717,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas(model_sampling, scheduler_name, steps):
@@ -719,6 +733,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "beta":
sigmas = beta_scheduler(model_sampling, steps)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas