mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
support TAESD3 (#3738)
This commit is contained in:
17
nodes.py
17
nodes.py
@@ -634,6 +634,8 @@ class VAELoader:
|
||||
sdxl_taesd_dec = False
|
||||
sd1_taesd_enc = False
|
||||
sd1_taesd_dec = False
|
||||
sd3_taesd_enc = False
|
||||
sd3_taesd_dec = False
|
||||
|
||||
for v in approx_vaes:
|
||||
if v.startswith("taesd_decoder."):
|
||||
@@ -644,10 +646,16 @@ class VAELoader:
|
||||
sdxl_taesd_dec = True
|
||||
elif v.startswith("taesdxl_encoder."):
|
||||
sdxl_taesd_enc = True
|
||||
elif v.startswith("taesd3_decoder."):
|
||||
sd3_taesd_dec = True
|
||||
elif v.startswith("taesd3_encoder."):
|
||||
sd3_taesd_enc = True
|
||||
if sd1_taesd_dec and sd1_taesd_enc:
|
||||
vaes.append("taesd")
|
||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||
vaes.append("taesdxl")
|
||||
if sd3_taesd_dec and sd3_taesd_enc:
|
||||
vaes.append("taesd3")
|
||||
return vaes
|
||||
|
||||
@staticmethod
|
||||
@@ -670,6 +678,8 @@ class VAELoader:
|
||||
sd["vae_scale"] = torch.tensor(0.18215)
|
||||
elif name == "taesdxl":
|
||||
sd["vae_scale"] = torch.tensor(0.13025)
|
||||
elif name == "taesd3":
|
||||
sd["vae_scale"] = torch.tensor(1.5305)
|
||||
return sd
|
||||
|
||||
@classmethod
|
||||
@@ -682,12 +692,15 @@ class VAELoader:
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
if vae_name in ["taesd", "taesdxl"]:
|
||||
if vae_name in ["taesd", "taesdxl", "taesd3"]:
|
||||
sd = self.load_taesd(vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
|
||||
latent_channels = 16 if vae_name == 'taesd3' else 4
|
||||
|
||||
vae = comfy.sd.VAE(sd=sd, latent_channels=latent_channels)
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
|
Reference in New Issue
Block a user