1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Move latent scale factor from VAE to model.

This commit is contained in:
comfyanonymous
2023-06-23 02:14:12 -04:00
parent 30a3861946
commit 8607c2d42d
7 changed files with 73 additions and 33 deletions

View File

@@ -536,7 +536,7 @@ class CLIP:
class VAE:
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
def __init__(self, ckpt_path=None, device=None, config=None):
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -550,7 +550,6 @@ class VAE:
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)
self.scale_factor = scale_factor
if device is None:
device = model_management.get_torch_device()
self.device = device
@@ -561,7 +560,7 @@ class VAE:
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
@@ -575,7 +574,7 @@ class VAE:
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -593,7 +592,7 @@ class VAE:
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
@@ -620,7 +619,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -958,6 +957,7 @@ def load_gligen(ckpt_path):
return model
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually
if config is None:
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
@@ -992,12 +992,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)
class EmptyClass:
pass
model_config = EmptyClass()
model_config.unet_config = unet_config
from . import latent_formats
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
model = model_base.BaseModel(model_config, v_prediction=v_prediction)
if fp16:
model = model.half()
@@ -1006,14 +1014,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if output_vae:
w = WeightsLoader()
vae = VAE(scale_factor=scale_factor, config=vae_config)
vae = VAE(config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
if output_clip:
w = WeightsLoader()
class EmptyClass:
pass
clip_target = EmptyClass()
clip_target.params = clip_config.get("params", {})
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
@@ -1055,7 +1061,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
vae = VAE(scale_factor=model_config.vae_scale_factor)
vae = VAE()
w = WeightsLoader()
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)