1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 23:49:57 +08:00

Playground V2.5 support with ModelSamplingContinuousEDM node.

Use ModelSamplingContinuousEDM with edm_playground_v2.5 selected.
This commit is contained in:
comfyanonymous
2024-02-27 15:12:33 -05:00
parent 1e0fcc9a65
commit d46583ecec
4 changed files with 48 additions and 7 deletions

View File

@@ -17,6 +17,11 @@ class V_PREDICTION(EPS):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
@@ -92,8 +97,6 @@ class ModelSamplingDiscrete(torch.nn.Module):
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
@@ -101,9 +104,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_parameters(sigma_min, sigma_max, sigma_data)
def set_sigma_range(self, sigma_min, sigma_max):
def set_parameters(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers