mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Fix SAG.
This commit is contained in:
@@ -5,12 +5,12 @@ import math
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import os
|
||||
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.samplers
|
||||
|
||||
# from comfy/ldm/modules/attention.py
|
||||
# but modified to return attention scores as well as output
|
||||
def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
@@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if attn_precision == torch.float32:
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
@@ -121,13 +121,13 @@ class SelfAttentionGuidance:
|
||||
if 1 in cond_or_uncond:
|
||||
uncond_index = cond_or_uncond.index(1)
|
||||
# do the entire attention operation, but save the attention scores to attn_scores
|
||||
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
|
||||
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
|
||||
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
|
||||
n_slices = heads * b
|
||||
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
|
||||
return out
|
||||
else:
|
||||
return optimized_attention(q, k, v, heads=heads)
|
||||
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
|
||||
|
||||
def post_cfg_function(args):
|
||||
nonlocal attn_scores
|
||||
|
Reference in New Issue
Block a user