mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Make VAE Encode tiled node work with video VAE.
This commit is contained in:
56
comfy/sd.py
56
comfy/sd.py
@@ -336,6 +336,7 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, (a + 3) / 6), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
||||
@@ -344,12 +345,14 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.downscale_ratio = (lambda a: max(0, (a + 4) / 8), 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
ddconfig["time_compress"] = 4
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, (a + 2) / 4), 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
@@ -385,10 +388,12 @@ class VAE:
|
||||
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
    """Center-crop ``pixels`` so each cropped dimension is a multiple of the encode ratio.

    Every dimension except the first and the last is cropped (``pixels.shape[1:-1]``).
    The ratio comes from ``self.spacial_compression_encode()`` rather than from
    ``self.downscale_ratio`` directly, because the latter may be a tuple
    (the video VAEs store a callable plus two ints) and cannot be used with
    ``//`` or ``%``.

    Returns the input tensor itself when no cropping is needed, otherwise a
    narrowed view of it.
    """
    downscale_ratio = self.spacial_compression_encode()

    dims = pixels.shape[1:-1]
    for d, size in enumerate(dims):
        # Largest multiple of the ratio that fits in this dimension.
        x = (size // downscale_ratio) * downscale_ratio
        # Split the discarded remainder evenly so the crop is centered.
        x_offset = (size % downscale_ratio) // 2
        if x != size:
            # +1 skips the leading (batch) dimension excluded from ``dims``.
            pixels = pixels.narrow(d + 1, x_offset, x)
    return pixels
|
||||
@@ -409,7 +414,7 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
    """Decode ``samples`` in tiles along a single dimension.

    Each tile is cast to ``self.vae_dtype``, moved to ``self.device``, decoded
    by ``self.first_stage_model`` and converted back to float. The stitched
    result from ``comfy.utils.tiled_scale_multidim`` is passed through
    ``self.process_output`` before being returned.

    Args:
        samples: input handed to ``tiled_scale_multidim``.
        tile_x: tile length along the tiled dimension.
        overlap: overlap between neighbouring tiles.
    """
    decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
    # Single return: the earlier un-processed return made this one unreachable
    # and skipped process_output entirely.
    return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
@@ -432,6 +437,10 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
    """Encode ``samples`` tile by tile over three dimensions.

    Tiling is delegated to ``comfy.utils.tiled_scale_multidim`` with
    ``downscale=True`` and ``self.downscale_ratio`` as the scale amount.

    Args:
        samples: input handed to ``tiled_scale_multidim``.
        tile_t, tile_x, tile_y: tile sizes, passed in that order as ``tile``.
            (presumably time, then the two spatial axes - confirm against caller)
        overlap: per-dimension overlap, forwarded unchanged.
    """
    def encode_fn(a):
        # Pre-process, then cast and move to the VAE's dtype and device before encoding.
        prepared = self.process_input(a).to(self.vae_dtype).to(self.device)
        return self.first_stage_model.encode(prepared).float()

    return comfy.utils.tiled_scale_multidim(
        samples,
        encode_fn,
        tile=(tile_t, tile_x, tile_y),
        overlap=overlap,
        upscale_amount=self.downscale_ratio,
        out_channels=self.latent_channels,
        downscale=True,
        output_device=self.output_device,
    )
|
||||
|
||||
def decode(self, samples_in):
|
||||
pixel_samples = None
|
||||
try:
|
||||
@@ -504,18 +513,43 @@ class VAE:
|
||||
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
if len(pixel_samples.shape) == 3:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
    """Encode ``pixel_samples`` with the tiled encoder matching ``self.latent_dim``.

    The input is cropped with ``vae_encode_crop_pixels``, its last dimension is
    moved to position 1, and for 3D latents dimensions 0 and 1 are swapped and
    a new leading dimension is added. The model is then loaded with the memory
    estimate from ``self.memory_used_encode`` and the call is dispatched to
    ``encode_tiled_1d``, ``encode_tiled_`` or ``encode_tiled_3d``.

    Args:
        pixel_samples: input tensor; the last dimension is moved to index 1
            (presumably channels-last images - confirm against callers).
        tile_x, tile_y, overlap: optional overrides. Arguments left as ``None``
            are not forwarded, so the selected encoder's own defaults apply.

    Returns:
        The output of the selected tiled encoder.
    """
    pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
    dims = self.latent_dim
    pixel_samples = pixel_samples.movedim(-1, 1)
    if dims == 3:
        pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)

    memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)  # TODO: calculate mem required for tile
    model_management.load_models_gpu([self.patcher], memory_required=memory_used)

    # Forward only the overrides the caller actually supplied.
    args = {}
    if tile_x is not None:
        args["tile_x"] = tile_x
    if tile_y is not None:
        args["tile_y"] = tile_y
    if overlap is not None:
        args["overlap"] = overlap

    if dims == 1:
        # The 1D encoder has no tile_y. The key is absent whenever tile_y was
        # left as None, so pop with a default instead of raising KeyError.
        args.pop("tile_y", None)
        samples = self.encode_tiled_1d(pixel_samples, **args)
    elif dims == 2:
        samples = self.encode_tiled_(pixel_samples, **args)
    elif dims == 3:
        samples = self.encode_tiled_3d(pixel_samples, **args)

    return samples
|
||||
|
||||
def get_sd(self):
|
||||
@@ -527,6 +561,12 @@ class VAE:
|
||||
except:
|
||||
return self.upscale_ratio
|
||||
|
||||
def spacial_compression_encode(self):
    """Return the encode-side compression ratio as a single value.

    ``self.downscale_ratio`` is either a plain number or a sequence. For a
    sequence the last element is returned; otherwise the value is returned
    as is.
    """
    try:
        return self.downscale_ratio[-1]
    except (TypeError, IndexError):
        # TypeError: the ratio is a non-subscriptable scalar.
        # IndexError: empty sequence; return it unchanged as before.
        return self.downscale_ratio
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
|
Reference in New Issue
Block a user