1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Only enable attention upcasting on models that actually need it.

This commit is contained in:
comfyanonymous
2024-05-14 15:18:00 -04:00
parent b0ab31d06c
commit bb4940d837
5 changed files with 27 additions and 24 deletions

View File

@@ -431,6 +431,7 @@ class UNetModel(nn.Module):
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
attn_precision=None,
device=None,
operations=ops,
):
@@ -550,13 +551,14 @@ class UNetModel(nn.Module):
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(