commit ec6f16adb6 (parent bb4940d837)
Author: comfyanonymous
Date: 2024-05-14 18:02:27 -04:00

2 changed files with 9 additions and 7 deletions


@@ -5,12 +5,12 @@ import math
 
 from einops import rearrange, repeat
 import os
-from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
+from comfy.ldm.modules.attention import optimized_attention
 import comfy.samplers
 
 # from comfy/ldm/modules/attention.py
 # but modified to return attention scores as well as output
-def attention_basic_with_sim(q, k, v, heads, mask=None):
+def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
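This hunk threads a new optional attn_precision argument through the copied attention helper instead of reading the old module-level _ATTN_PRECISION flag. A minimal sketch of the same pattern (not ComfyUI's code; names are illustrative): defaulting the argument to None keeps existing call sites working, and only an explicit torch.float32 request switches the math to fp32.

```python
import torch

# Sketch of the pattern, not ComfyUI's code: an optional attn_precision
# argument (default None) replaces a module-level flag, so precision is
# chosen per call instead of globally.
def attention_scores(q, k, scale, attn_precision=None):
    if attn_precision == torch.float32:
        # caller explicitly asked for fp32: upcast before the dot product
        q, k = q.float(), k.float()
    return torch.einsum('b i d, b j d -> b i j', q, k) * scale

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
out = attention_scores(q, k, scale=8 ** -0.5)                                  # default path
out_fp32 = attention_scores(q, k, scale=8 ** -0.5, attn_precision=torch.float32)
print(out.shape, out_fp32.shape)  # torch.Size([2, 4, 4]) both times
```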
@@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
     )
 
     # force cast to fp32 to avoid overflowing
-    if _ATTN_PRECISION =="fp32":
+    if attn_precision == torch.float32:
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
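The fp32 cast guards against fp16 overflow: float16 saturates just above 65504, and the q·k dot products in attention can exceed that. A small standalone demonstration (values are hypothetical; elementwise multiply-and-sum stands in for the einsum so it runs on CPU):

```python
import torch

# Why the upcast matters: float16 overflows past ~65504, so summing
# 256 products of 20.0 * 20.0 (total 102400) saturates to inf in fp16.
q = torch.full((1, 1, 256), 20.0, dtype=torch.float16)
k = torch.full((1, 1, 256), 20.0, dtype=torch.float16)

sim_fp16 = (q * k).sum(dim=-1)                    # stays in fp16 -> inf
sim_fp32 = (q.float() * k.float()).sum(dim=-1)    # upcast first  -> 102400.0
print(sim_fp16.item(), sim_fp32.item())           # inf 102400.0
```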
@@ -121,13 +121,13 @@ class SelfAttentionGuidance:
             if 1 in cond_or_uncond:
                 uncond_index = cond_or_uncond.index(1)
                 # do the entire attention operation, but save the attention scores to attn_scores
-                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
+                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
                 # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
                 n_slices = heads * b
                 attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
                 return out
             else:
-                return optimized_attention(q, k, v, heads=heads)
+                return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
 
         def post_cfg_function(args):
             nonlocal attn_scores
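For the slicing in the uncond branch above: sim has shape (batch * heads, seq, seq), and per the in-code comment the batch dimension concatenates all uncond samples before the cond ones, so the uncond scores occupy one contiguous block of heads * b rows. A toy sketch of that indexing (sizes and the cond_or_uncond convention are assumptions inferred from the code above):

```python
import torch

# Toy sketch of the uncond-slice indexing; all sizes are hypothetical.
heads, b, seq = 2, 3, 4                # b = samples per condition
cond_or_uncond = [0, 1]                # 0 = cond, 1 = uncond, per the code above
sim = torch.randn(len(cond_or_uncond) * b * heads, seq, seq)

uncond_index = cond_or_uncond.index(1)
n_slices = heads * b                   # rows belonging to one condition
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
print(attn_scores.shape)               # torch.Size([6, 4, 4])
```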