mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Compare commits
3 Commits
9cca36fa2b
...
382f84a826
Author | SHA1 | Date | |
---|---|---|---|
|
382f84a826 | ||
|
2f74e17975 | ||
|
dca6bdd4fa |
@@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_e(e, x):
|
||||||
|
repeats = 1
|
||||||
|
if e.shape[1] > 1:
|
||||||
|
repeats = x.shape[1] // e.shape[1]
|
||||||
|
if repeats == 1:
|
||||||
|
return e
|
||||||
|
return torch.repeat_interleave(e, repeats, dim=1)
|
||||||
|
|
||||||
|
|
||||||
class WanAttentionBlock(nn.Module):
|
class WanAttentionBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||||
"""
|
"""
|
||||||
# assert e.dtype == torch.float32
|
# assert e.dtype == torch.float32
|
||||||
|
|
||||||
if e.ndim < 4:
|
if e.ndim < 4:
|
||||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||||
else:
|
else:
|
||||||
@@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
self.norm1(x) * (1 + e[1]) + e[0],
|
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
|
||||||
freqs)
|
freqs)
|
||||||
|
|
||||||
x = x + y * e[2]
|
x = x + y * repeat_e(e[2], x)
|
||||||
|
|
||||||
# cross-attention & ffn
|
# cross-attention & ffn
|
||||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||||
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
|
||||||
x = x + y * e[5]
|
x = x + y * repeat_e(e[5], x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -331,7 +341,8 @@ class Head(nn.Module):
|
|||||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||||
else:
|
else:
|
||||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
|
||||||
|
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1202,7 +1202,7 @@ class WAN22(BaseModel):
|
|||||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||||
if denoise_mask is None:
|
if denoise_mask is None:
|
||||||
return timestep
|
return timestep
|
||||||
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
|
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
|
||||||
return temp_ts
|
return temp_ts
|
||||||
|
|
||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
|
@@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.46"
|
__version__ = "0.3.47"
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.46"
|
version = "0.3.47"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
Reference in New Issue
Block a user