mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Optimize first attention block in cosmos VAE.
This commit is contained in:
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
|
||||
return out
|
||||
|
||||
|
||||
def vae_attention():
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.info("Using xformers attention in VAE")
|
||||
return xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
return pytorch_attention
|
||||
else:
|
||||
logging.info("Using split attention in VAE")
|
||||
return normal_attention
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.info("Using xformers attention in VAE")
|
||||
self.optimized_attention = xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
self.optimized_attention = pytorch_attention
|
||||
else:
|
||||
logging.info("Using split attention in VAE")
|
||||
self.optimized_attention = normal_attention
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
|
Reference in New Issue
Block a user