1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Merge branch 'master' into v3-definition

This commit is contained in:
Jedrzej Kosinski
2025-07-29 19:52:15 -07:00
6 changed files with 44 additions and 8 deletions

View File

@@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
} }
def repeat_e(e, x):
repeats = 1
if e.shape[1] > 1:
repeats = x.shape[1] // e.shape[1]
if repeats == 1:
return e
return torch.repeat_interleave(e, repeats, dim=1)
class WanAttentionBlock(nn.Module): class WanAttentionBlock(nn.Module):
def __init__(self, def __init__(self,
@@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
""" """
# assert e.dtype == torch.float32 # assert e.dtype == torch.float32
if e.ndim < 4: if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else: else:
@@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module):
# self-attention # self-attention
y = self.self_attn( y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0], self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
freqs) freqs)
x = x + y * e[2] x = x + y * repeat_e(e[2], x)
# cross-attention & ffn # cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * e[5] x = x + y * repeat_e(e[5], x)
return x return x
@@ -331,7 +341,8 @@ class Head(nn.Module):
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else: else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
return x return x

View File

@@ -1202,7 +1202,7 @@ class WAN22(BaseModel):
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None: if denoise_mask is None:
return timestep return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts return temp_ts
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Union from typing import Optional, Union
import io import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC): class VideoInput(ABC):
@@ -70,3 +71,15 @@ class VideoInput(ABC):
components = self.get_components() components = self.get_components()
frame_count = components.images.shape[0] frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate) return float(frame_count / components.frame_rate)
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
# Default implementation - subclasses should override for better performance
source = self.get_stream_source()
with av.open(source, mode="r") as container:
return container.format.name

View File

@@ -120,6 +120,18 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'") raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents: def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames # Get video frames
frames = [] frames = []

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.46" __version__ = "0.3.47"

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.46" version = "0.3.47"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"