1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Fix sub quadratic attention for SD2 and make it the default optimization.

This commit is contained in:
comfyanonymous
2023-01-25 01:22:43 -05:00
parent 3b38a31cc7
commit 051f472e8f
2 changed files with 60 additions and 26 deletions

View File

@@ -53,14 +53,27 @@ def _summarize_chunk(
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
query = query.float()
key_t = key_t.float()
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
else:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
@@ -112,14 +125,27 @@ def _get_attention_scores_no_kv_chunking(
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
query = query.float()
key_t = key_t.float()
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
else:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
@@ -137,6 +163,7 @@ def efficient_dot_product_attention(
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
):
"""Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in
@@ -170,11 +197,12 @@ def efficient_dot_product_attention(
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
scale=scale,
upcast_attention=upcast_attention
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(