Add support for attention masking in Flux (#5942)
* fix attention OOM in xformers
* allow passing attention mask in flux attention
* allow an attn_mask in flux
* attn masks can be done using replace patches instead of a separate dict
* fix return types
* fix return order
* enumerate
* patch the right keys
* arg names
* fix a silly bug
* fix xformers masks
* replace match with if, elif, else
* mask with image_ref_size
* remove unused import
* remove unused import 2
* fix pytorch/xformers attention

  This corrects a weird inconsistency with skip_reshape. It also allows masks of various shapes to be passed, which will be automatically expanded (in a memory-efficient way) to a size that is compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
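The last items describe how the attention backends now accept masks of various shapes and broadcast them up to the full attention shape without materializing copies. A minimal sketch of that idea using plain torch.nn.functional.scaled_dot_product_attention; the expand_attn_mask helper and the token counts below are made up for illustration and are not the code added by this commit:

import torch
import torch.nn.functional as F

def expand_attn_mask(mask, batch, heads, q_len, k_len):
    # Hypothetical helper (not ComfyUI's actual attention code): broadcast a
    # 2-D or 3-D additive mask up to the (batch, heads, q_len, k_len) shape
    # that torch sdpa accepts. expand() only creates a view, so the repeated
    # dimensions cost no extra memory.
    if mask.ndim == 2:        # (q_len, k_len)
        mask = mask.unsqueeze(0).unsqueeze(0)
    elif mask.ndim == 3:      # (batch, q_len, k_len)
        mask = mask.unsqueeze(1)
    return mask.expand(batch, heads, q_len, k_len)

# toy shapes: batch 2, 8 heads, 77 text tokens + 256 image tokens = 333
q = torch.randn(2, 8, 333, 64)
k = torch.randn(2, 8, 333, 64)
v = torch.randn(2, 8, 333, 64)
bias = torch.zeros(2, 333, 333)   # additive bias; 0 means "no effect"
out = F.scaled_dot_product_attention(q, k, v, attn_mask=expand_attn_mask(bias, 2, 8, 333, 333))
print(out.shape)                  # torch.Size([2, 8, 333, 64])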
nodes.py
@@ -1008,23 +1008,58 @@ class StyleModelApply:
                              "style_model": ("STYLE_MODEL", ),
                              "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
-                             "strength_type": (["multiply"], ),
+                             "strength_type": (["multiply", "attn_bias"], ),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_stylemodel"
 
     CATEGORY = "conditioning/style_model"
 
-    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
+    def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
         cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
         if strength_type == "multiply":
             cond *= strength
 
-        c = []
+        n = cond.shape[1]
+        c_out = []
         for t in conditioning:
-            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
-            c.append(n)
-        return (c, )
+            (txt, keys) = t
+            keys = keys.copy()
+            if strength_type == "attn_bias" and strength != 1.0:
+                # math.log raises an error if the argument is zero
+                # torch.log returns -inf, which is what we want
+                attn_bias = torch.log(torch.Tensor([strength]))
+                # get the size of the mask image
+                mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
+                n_ref = mask_ref_size[0] * mask_ref_size[1]
+                n_txt = txt.shape[1]
+                # grab the existing mask
+                mask = keys.get("attention_mask", None)
+                # create a default mask if it doesn't exist
+                if mask is None:
+                    mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
+                # convert the mask dtype, because it might be boolean
+                # we want it to be interpreted as a bias
+                if mask.dtype == torch.bool:
+                    # log(True) = log(1) = 0
+                    # log(False) = log(0) = -inf
+                    mask = torch.log(mask.to(dtype=torch.float16))
+                # now we make the mask bigger to add space for our new tokens
+                new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
+                # copy over the old mask, in quadrants
+                new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
+                new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
+                new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
+                new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
+                # now fill in the attention bias to our redux tokens
+                new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
+                new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
+                keys["attention_mask"] = new_mask.to(txt.device)
+                keys["attention_mask_img_shape"] = mask_ref_size
+
+            c_out.append([torch.cat((txt, cond), dim=1), keys])
+
+        return (c_out,)
 
 class unCLIPConditioning:
     @classmethod
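A note on the new attn_bias branch: strength is not multiplied into the conditioning here, it is stored as an additive bias on the pre-softmax attention logits, which is why it is taken in log space (adding log(s) to a logit multiplies that attention weight by s before normalization). An illustrative check of the values, matching the torch.log call in the diff (printed values rounded in the comments):

import torch

# math.log(0) raises ValueError; torch.log(0) returns -inf, which is the
# desired "fully masked" value.
for s in (0.0, 0.5, 1.0, 2.0):
    print(s, torch.log(torch.Tensor([s])).item())
# 0.0 -> -inf    redux tokens are masked out entirely
# 0.5 -> -0.693  attention to redux tokens is halved (pre-normalization)
# 1.0 ->  0.0    no change, which is why the branch skips strength == 1.0
# 2.0 ->  0.693  attention to redux tokens is doubled (pre-normalization)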
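The quadrant copy in the middle of that branch is easier to follow with concrete shapes. Below is a small standalone sketch of the same logic; the grow_mask helper name is hypothetical (the real code is inline in apply_stylemodel), and 2 text tokens, 3 reference-image tokens and 1 redux token keep the tensor small enough to print:

import torch

def grow_mask(mask, n_txt, n, bias):
    # Hypothetical standalone version of the inline logic in apply_stylemodel:
    # insert n new (redux) tokens between the n_txt text tokens and the
    # reference-image tokens, copy the old mask over in four quadrants, and
    # bias every text/reference query's attention to the new tokens by `bias`.
    b, old_len, _ = mask.shape
    new_mask = torch.zeros((b, old_len + n, old_len + n), dtype=mask.dtype)
    new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]          # txt -> txt
    new_mask[:, :n_txt, n_txt + n:] = mask[:, :n_txt, n_txt:]      # txt -> ref
    new_mask[:, n_txt + n:, :n_txt] = mask[:, n_txt:, :n_txt]      # ref -> txt
    new_mask[:, n_txt + n:, n_txt + n:] = mask[:, n_txt:, n_txt:]  # ref -> ref
    new_mask[:, :n_txt, n_txt:n_txt + n] = bias                    # txt -> redux
    new_mask[:, n_txt + n:, n_txt:n_txt + n] = bias                # ref -> redux
    return new_mask

# 2 text tokens, 3 reference-image tokens, 1 redux token, strength 0.5
old_mask = torch.zeros((1, 5, 5), dtype=torch.float16)
grown = grow_mask(old_mask, n_txt=2, n=1, bias=torch.log(torch.Tensor([0.5])).half())
print(grown[0])
# column 2 (attention *to* the redux token) now holds log(0.5), roughly -0.69,
# in the text and reference-image query rows; everything else is copied unchanged.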