mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00

Compare commits: 5d5024296d ... v0.3.47 (4 commits)

SHA1:
  2f74e17975
  dca6bdd4fa
  7d593baf91
  c60dc4177c
@@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
 }
 
 
+def repeat_e(e, x):
+    repeats = 1
+    if e.shape[1] > 1:
+        repeats = x.shape[1] // e.shape[1]
+    if repeats == 1:
+        return e
+    return torch.repeat_interleave(e, repeats, dim=1)
+
+
 class WanAttentionBlock(nn.Module):
 
     def __init__(self,
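
Note, not part of the diff: a standalone sketch of what the new repeat_e helper does, with the function duplicated so it runs on its own and purely hypothetical toy shapes. torch.repeat_interleave stretches a short per-chunk tensor along dim=1 so each of its rows covers a contiguous run of tokens in x; when e already has a single row on that dim it is returned unchanged and ordinary broadcasting applies.

import torch

def repeat_e(e, x):
    repeats = 1
    if e.shape[1] > 1:
        repeats = x.shape[1] // e.shape[1]
    if repeats == 1:
        return e
    return torch.repeat_interleave(e, repeats, dim=1)

e = torch.arange(3.0).reshape(1, 3, 1)   # hypothetical [B=1, 3 chunks, C=1]
x = torch.zeros(1, 12, 1)                # hypothetical [B=1, 12 tokens, C=1]
out = repeat_e(e, x)
print(out.shape)       # torch.Size([1, 12, 1])
print(out[0, :, 0])    # tensor([0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2.])
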
@@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         # assert e.dtype == torch.float32
+
         if e.ndim < 4:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         else:
@@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module):
 
         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + e[1]) + e[0],
+            self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
             freqs)
 
-        x = x + y * e[2]
+        x = x + y * repeat_e(e[2], x)
 
         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
-        y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
-        x = x + y * e[5]
+        y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
+        x = x + y * repeat_e(e[5], x)
         return x
 
 
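
Note, not part of the diff: a shape check of the modulation pattern used above, with hypothetical sizes and torch.nn.functional.layer_norm standing in for self.norm1. Before this change each e[i] was effectively [B, 1, C] and relied on broadcasting; with per-chunk modulation ([B, T, C], T > 1) repeat_e makes the shift and scale line up with every token.

import torch
import torch.nn.functional as F

B, L, C = 1, 8, 4                    # hypothetical: 8 tokens in 2 temporal chunks
x = torch.randn(B, L, C)
e = list(torch.randn(6, B, 2, C))    # stand-ins for e[0]..e[5], one row per chunk

def repeat_e(e_i, x):
    repeats = x.shape[1] // e_i.shape[1] if e_i.shape[1] > 1 else 1
    return e_i if repeats == 1 else torch.repeat_interleave(e_i, repeats, dim=1)

y_in = F.layer_norm(x, (C,)) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)
print(y_in.shape)    # torch.Size([1, 8, 4]) -- modulation now matches every token
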
@@ -331,7 +341,8 @@ class Head(nn.Module):
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
-        x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+
+        x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
         return x
 
 
@@ -136,7 +136,7 @@ class ResidualBlock(nn.Module):
                          if in_dim != out_dim else nn.Identity())
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        h = self.shortcut(x)
+        old_x = x
         for layer in self.residual:
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
@@ -156,7 +156,7 @@ class ResidualBlock(nn.Module):
                     feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x + h
+        return x + self.shortcut(old_x)
 
 
 def patchify(x, patch_size):
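
Note, not part of the diff: the two ResidualBlock hunks defer the shortcut projection to the return statement, so only a reference to the original input is kept alive while the residual stack runs rather than a projected copy (presumably to lower peak memory). A minimal sketch of the pattern using a hypothetical TinyResidual block, not the repo's class:

import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.residual = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))
        self.shortcut = nn.Linear(in_dim, out_dim) if in_dim != out_dim else nn.Identity()

    def forward(self, x):
        old_x = x                          # plain reference, no extra allocation
        for layer in self.residual:
            x = layer(x)
        return x + self.shortcut(old_x)    # project the shortcut only at the end

print(TinyResidual(8, 16)(torch.randn(2, 8)).shape)   # torch.Size([2, 16])
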
@@ -327,7 +327,7 @@ class Down_ResidualBlock(nn.Module):
         self.downsamples = nn.Sequential(*downsamples)
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        x_copy = x.clone()
+        x_copy = x
         for module in self.downsamples:
             x = module(x, feat_cache, feat_idx)
 
@@ -369,7 +369,7 @@ class Up_ResidualBlock(nn.Module):
         self.upsamples = nn.Sequential(*upsamples)
 
     def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
-        x_main = x.clone()
+        x_main = x
         for module in self.upsamples:
             x_main = module(x_main, feat_cache, feat_idx)
         if self.avg_shortcut is not None:
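
Note, not part of the diff: both VAE hunks drop a .clone(). The loops only rebind the local name (x = module(...)); they never write into the tensor in place, so keeping a plain reference preserves the original data and the copy was presumably redundant. A tiny illustration:

import torch

x = torch.ones(3)
x_copy = x            # reference only, no data copied
x = x * 2             # rebinding allocates a new tensor; x_copy is untouched
print(x_copy)         # tensor([1., 1., 1.])
print(x)              # tensor([2., 2., 2.])
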
@@ -1202,7 +1202,7 @@ class WAN22(BaseModel):
     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
         if denoise_mask is None:
             return timestep
-        temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
+        temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
         return temp_ts
 
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
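
Note, not part of the diff: the new expression averages the denoise mask over the channel and full spatial dims instead of over channels on a 2x-strided spatial subsample, so each latent frame gets a single mask-weighted timestep rather than a per-position map. A shape check with hypothetical sizes:

import torch

B, C, T, H, W = 1, 16, 4, 8, 8
denoise_mask = torch.rand(B, C, T, H, W)
timestep = torch.tensor([999.0])
t_bcast = timestep.view([B] + [1] * (denoise_mask.ndim - 1))

old_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * t_bcast).reshape(B, -1)
new_ts = (torch.mean(denoise_mask, dim=(1, 3, 4), keepdim=True) * t_bcast).reshape(B, -1)
print(old_ts.shape)   # torch.Size([1, 64]) -- one value per subsampled spatial position
print(new_ts.shape)   # torch.Size([1, 4])  -- one value per latent frame
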
@@ -529,6 +529,8 @@ WINDOWS = any(platform.win32_ver())
 EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
 if WINDOWS:
     EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
+    if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
+        EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
 
 if args.reserve_vram is not None:
     EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
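
Note, not part of the diff: a standalone restatement of the reservation arithmetic above, wrapped in a hypothetical helper for illustration. 400 MB is reserved by default, 600 MB on Windows because of the shared-vram issue, plus another 100 MB on Windows cards with more than 15 GB of VRAM; an explicit --reserve-vram value (in GB) overrides all of it.

MiB = 1024 * 1024

def extra_reserved_vram(windows, total_vram_mb, reserve_vram_gb=None):
    reserved = 400 * MiB
    if windows:
        reserved = 600 * MiB             # shared-vram overhead on Windows
        if total_vram_mb > 15 * 1024:    # 16GB+ cards get another 100MB
            reserved += 100 * MiB
    if reserve_vram_gb is not None:      # explicit --reserve-vram wins
        reserved = reserve_vram_gb * 1024 * MiB
    return reserved

print(extra_reserved_vram(windows=True, total_vram_mb=24 * 1024) // MiB)    # 700
print(extra_reserved_vram(windows=False, total_vram_mb=8 * 1024) // MiB)    # 400
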
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.46"
+__version__ = "0.3.47"
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.46"
+version = "0.3.47"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"