1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Add support for attention masking in Flux (#5942)

* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automtically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
This commit is contained in:
Raphael Walker
2024-12-17 00:21:17 +01:00
committed by GitHub
parent 0f954f34af
commit 61b50720d0
7 changed files with 182 additions and 48 deletions

View File

@@ -26,6 +26,8 @@ import numpy as np
from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@@ -873,5 +875,46 @@ def reshape_mask(input_mask, output_shape):
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
mask = repeat_to_batch_size(mask, output_shape[0])
return mask
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in
ho, wo = img_size_out
# if it's already the correct size, no need to do anything
if (hi, wi) == (ho, wo):
return mask
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim != 3:
raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
txt_tokens = mask.shape[1] - (hi * wi)
# quadrants of the mask
txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
txt_to_img = mask[:, :txt_tokens, txt_tokens:]
img_to_img = mask[:, txt_tokens:, txt_tokens:]
img_to_txt = mask[:, txt_tokens:, :txt_tokens]
# convert to 1d x 2d, interpolate, then back to 1d x 1d
txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
# this one is hard because we have to do it twice
# convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
# convert to 2d x 1d, interpolate, then back to 1d x 1d
img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks
out = torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1
)
return out