mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Multi dimension tiled scale function and tiled VAE audio encoding fallback.
This commit is contained in:
31
comfy/sd.py
31
comfy/sd.py
@@ -298,25 +298,9 @@ class VAE:
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=64):
|
||||
output = torch.zeros((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
|
||||
output_mult = torch.zeros((samples.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples.shape[2:])), device=self.output_device)
|
||||
|
||||
for j in range(samples.shape[0]):
|
||||
for i in range(0, samples.shape[-1], tile_x - overlap):
|
||||
f = i
|
||||
t = i + tile_x
|
||||
o = output[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio]
|
||||
m = torch.ones_like(o)
|
||||
l = m.shape[-1]
|
||||
for x in range(overlap):
|
||||
c = ((x + 1) / overlap)
|
||||
m[:,:,x:x+1] *= c
|
||||
m[:,:,l-x-1:l-x] *= c
|
||||
o += self.first_stage_model.decode(samples[j:j+1,:,f:t].to(self.vae_dtype).to(self.device)).float().to(self.output_device) * m
|
||||
output_mult[j:j+1,:,f * self.upscale_ratio:t * self.upscale_ratio] += m
|
||||
|
||||
return output / output_mult
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
@@ -331,6 +315,10 @@ class VAE:
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
@@ -374,7 +362,10 @@ class VAE:
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
if len(pixel_samples.shape) == 3:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
|
Reference in New Issue
Block a user