
Add support for attention masking in Flux (#5942)

* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automatically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa, respectively (see the sketch after this list).

* fix mask shapes
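
As an illustration of the mask expansion mentioned in the "fix pytorch/xformers attention" entry above, here is a minimal sketch (not code from this commit; the helper name and exact dimension handling are assumptions) of how a 2D or 3D additive mask can be broadcast to the (batch, heads, q_len, kv_len) shape that torch.nn.functional.scaled_dot_product_attention accepts. Using expand() produces a zero-copy view, which is what makes the expansion memory-efficient:

import torch
import torch.nn.functional as F

def expand_attn_mask(mask: torch.Tensor, batch: int, heads: int,
                     q_len: int, kv_len: int) -> torch.Tensor:
    # Hypothetical helper: accept a (q_len, kv_len), (batch, q_len, kv_len),
    # or already-4D mask and return a view broadcastable to
    # (batch, heads, q_len, kv_len) without allocating new memory.
    if mask.dim() == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)   # -> (1, 1, q_len, kv_len)
    elif mask.dim() == 3:
        mask = mask.unsqueeze(1)                # -> (batch, 1, q_len, kv_len)
    # expand() sets stride 0 on the broadcast dimensions, so no data is copied
    return mask.expand(batch, heads, q_len, kv_len)

q = torch.randn(2, 8, 16, 64)   # (batch, heads, q_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
bias = torch.zeros(16, 16)      # additive bias; 0 means "no effect"
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=expand_attn_mask(bias, batch=2, heads=8, q_len=16, kv_len=16))

xformers' memory_efficient_attention has its own shape requirements for attention biases, which is why the commit message treats the xformers and sdpa paths separately.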
Raphael Walker
2024-12-17 00:21:17 +01:00
committed by GitHub
parent 0f954f34af
commit 61b50720d0
7 changed files with 182 additions and 48 deletions


@@ -1008,23 +1008,58 @@ class StyleModelApply:
"style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ),
"strength_type": (["multiply", "attn_bias"], ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply":
cond *= strength
c = []
n = cond.shape[1]
c_out = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
c.append(n)
return (c, )
(txt, keys) = t
keys = keys.copy()
if strength_type == "attn_bias" and strength != 1.0:
# math.log raises an error if the argument is zero
# torch.log returns -inf, which is what we want
attn_bias = torch.log(torch.Tensor([strength]))
# get the size of the mask image
mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
n_ref = mask_ref_size[0] * mask_ref_size[1]
n_txt = txt.shape[1]
# grab the existing mask
mask = keys.get("attention_mask", None)
# create a default mask if it doesn't exist
if mask is None:
mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
# convert the mask dtype, because it might be boolean
# we want it to be interpreted as a bias
if mask.dtype == torch.bool:
# log(True) = log(1) = 0
# log(False) = log(0) = -inf
mask = torch.log(mask.to(dtype=torch.float16))
# now we make the mask bigger to add space for our new tokens
new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
# copy over the old mask, in quandrants
new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
# now fill in the attention bias to our redux tokens
new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
keys["attention_mask"] = new_mask.to(txt.device)
keys["attention_mask_img_shape"] = mask_ref_size
c_out.append([torch.cat((txt, cond), dim=1), keys])
return (c_out,)
class unCLIPConditioning:
@classmethod
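
To make the quadrant bookkeeping in the attn_bias branch above easier to follow, here is a standalone sketch of the same idea with illustrative names (insert_bias_tokens and n_new are not identifiers from the commit). It also shows why log(strength) works as the bias: adding log(s) to the pre-softmax logits multiplies the corresponding attention weights by s before renormalization, so strength behaves like a multiplicative weight on the new tokens.

import torch

def insert_bias_tokens(mask: torch.Tensor, n_txt: int, n_new: int,
                       strength: float) -> torch.Tensor:
    # mask: (batch, n_txt + n_ref, n_txt + n_ref) additive attention bias.
    # Returns a (batch, n_txt + n_new + n_ref, n_txt + n_new + n_ref) bias with
    # n_new tokens inserted right after the text tokens, biased by log(strength).
    batch, total, _ = mask.shape
    bias = torch.log(torch.tensor(float(strength)))   # -inf if strength == 0
    out = torch.zeros((batch, total + n_new, total + n_new), dtype=mask.dtype)
    # copy the old mask into the four corner quadrants around the inserted rows/columns
    out[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
    out[:, :n_txt, n_txt + n_new:] = mask[:, :n_txt, n_txt:]
    out[:, n_txt + n_new:, :n_txt] = mask[:, n_txt:, :n_txt]
    out[:, n_txt + n_new:, n_txt + n_new:] = mask[:, n_txt:, n_txt:]
    # bias attention from the text and reference-image tokens toward the new tokens
    out[:, :n_txt, n_txt:n_txt + n_new] = bias
    out[:, n_txt + n_new:, n_txt:n_txt + n_new] = bias
    return out

# example: 4 text tokens, a 2x2 reference image (n_ref = 4), 3 new tokens, strength 0.5
old = torch.zeros((1, 8, 8))
new = insert_bias_tokens(old, n_txt=4, n_new=3, strength=0.5)
print(new.shape)        # torch.Size([1, 11, 11])
print(new[0, 0, 4:7])   # tensor([-0.6931, -0.6931, -0.6931]) == log(0.5)

In the node itself the expanded mask is stored back on the conditioning dict under "attention_mask" (with "attention_mask_img_shape" recording the reference-image grid), which, per the other commits in this PR, is what the Flux attention patch consumes.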