1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 07:26:31 +08:00

Make wan2.2 5B i2v take a lot less memory. (#9102)

This commit is contained in:
comfyanonymous
2025-07-29 16:44:18 -07:00
committed by GitHub
parent 7d593baf91
commit dca6bdd4fa
2 changed files with 17 additions and 6 deletions

View File

@@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
} }
def repeat_e(e, x):
    """Stretch ``e`` along dim 1 towards the dim-1 length of ``x``.

    When ``e`` has at most one entry on dim 1 it is returned untouched.
    Otherwise each dim-1 entry of ``e`` is repeated
    ``x.shape[1] // e.shape[1]`` times, consecutively, via
    ``torch.repeat_interleave``. A factor of exactly 1 also returns ``e``
    itself rather than a copy.

    Args:
        e: Tensor with at least two dimensions.
        x: Tensor whose dim-1 length sets the repeat factor.

    Returns:
        ``e`` itself, or a new tensor with dim 1 expanded by the factor.
    """
    e_len = e.shape[1]
    # A single (or empty) entry on dim 1 is passed through as-is.
    if e_len <= 1:
        return e

    # Floor division: if x's length is not a multiple of e's, the result
    # ends up shorter than x along dim 1.
    factor = x.shape[1] // e_len
    if factor == 1:
        return e
    return torch.repeat_interleave(e, factor, dim=1)
class WanAttentionBlock(nn.Module): class WanAttentionBlock(nn.Module):
def __init__(self, def __init__(self,
@@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
""" """
# assert e.dtype == torch.float32 # assert e.dtype == torch.float32
if e.ndim < 4: if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else: else:
@@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module):
# self-attention # self-attention
y = self.self_attn( y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0], self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
freqs) freqs)
x = x + y * e[2] x = x + y * repeat_e(e[2], x)
# cross-attention & ffn # cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * e[5] x = x + y * repeat_e(e[5], x)
return x return x
@@ -331,7 +341,8 @@ class Head(nn.Module):
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else: else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
return x return x

View File

@@ -1202,7 +1202,7 @@ class WAN22(BaseModel):
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None: if denoise_mask is None:
return timestep return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts return temp_ts
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):