1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Add support for TAESD decoder for SDXL.

This commit is contained in:
comfyanonymous
2023-06-25 02:38:14 -04:00
parent 20f579d91d
commit cef6aa62b2
4 changed files with 24 additions and 15 deletions

View File

@@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
self.latent_rgb_factors = torch.tensor([
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
@@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
return Image.fromarray(latents_ubyte.numpy())
def get_previewer(device):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
@@ -86,10 +80,10 @@ def get_previewer(device):
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None:
previewer = Latent2RGBPreviewer()
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
return previewer