diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index b9e47e9f7..a93a13c86 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
 }
 
 
+def repeat_e(e, x):
+    repeats = 1
+    if e.shape[1] > 1:
+        repeats = x.shape[1] // e.shape[1]
+    if repeats == 1:
+        return e
+    return torch.repeat_interleave(e, repeats, dim=1)
+
+
 class WanAttentionBlock(nn.Module):
 
     def __init__(self,
@@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         # assert e.dtype == torch.float32
+
         if e.ndim < 4:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         else:
@@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module):
 
         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + e[1]) + e[0],
+            self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
             freqs)
 
-        x = x + y * e[2]
+        x = x + y * repeat_e(e[2], x)
 
         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
-        y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
-        x = x + y * e[5]
+        y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
+        x = x + y * repeat_e(e[5], x)
         return x
 
 
@@ -331,7 +341,8 @@ class Head(nn.Module):
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
-        x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+
+        x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
         return x
 
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index d019b991a..6b7978949 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1202,7 +1202,7 @@ class WAN22(BaseModel):
     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
         if denoise_mask is None:
             return timestep
-        temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
+        temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
         return temp_ts
 
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py
index bb936e0a4..5d95dc507 100644
--- a/comfy_api/latest/_input/video_types.py
+++ b/comfy_api/latest/_input/video_types.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Optional, Union
 import io
+import av
 from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
 
 class VideoInput(ABC):
@@ -70,3 +71,15 @@
         components = self.get_components()
         frame_count = components.images.shape[0]
         return float(frame_count / components.frame_rate)
+
+    def get_container_format(self) -> str:
+        """
+        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
+
+        Returns:
+            Container format as string
+        """
+        # Default implementation - subclasses should override for better performance
+        source = self.get_stream_source()
+        with av.open(source, mode="r") as container:
+            return container.format.name
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 2089307df..28de9651d 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -120,6 +120,18 @@ class VideoFromFile(VideoInput):
 
         raise ValueError(f"Could not determine duration for file '{self.__file}'")
 
+    def get_container_format(self) -> str:
+        """
+        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
+
+        Returns:
+            Container format as string
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)
+        with av.open(self.__file, mode='r') as container:
+            return container.format.name
+
     def get_components_internal(self, container: InputContainer) -> VideoComponents:
         # Get video frames
         frames = []
diff --git a/comfyui_version.py b/comfyui_version.py
index 315710dd2..20a2e892a 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.46"
+__version__ = "0.3.47"
diff --git a/pyproject.toml b/pyproject.toml
index f0a979145..4fd3acd79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.46"
+version = "0.3.47"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
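
For reference, a minimal standalone sketch of what the new `repeat_e` helper does: when the modulation tensor `e` carries more than one vector along dim=1 (per-segment timestep embeddings), each vector is repeated along that dimension so it matches the token dimension of `x` before the element-wise modulation in `WanAttentionBlock` and `Head`. The shapes below are illustrative only, not taken from the patch.

```python
import torch

def repeat_e(e, x):
    # Same logic as the helper added to comfy/ldm/wan/model.py: stretch e along
    # dim=1 so per-segment modulation lines up with the token dimension of x.
    repeats = 1
    if e.shape[1] > 1:
        repeats = x.shape[1] // e.shape[1]
    if repeats == 1:
        return e
    return torch.repeat_interleave(e, repeats, dim=1)

# Illustrative shapes: x has 8 tokens, e has 2 per-segment vectors, so each
# vector is repeated 4 times; a single-vector e passes through unchanged and
# broadcasts as before.
x = torch.zeros(1, 8, 16)
e = torch.randn(1, 2, 16)
print(repeat_e(e, x).shape)                      # torch.Size([1, 8, 16])
print(repeat_e(torch.randn(1, 1, 16), x).shape)  # torch.Size([1, 1, 16])
```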
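A usage sketch for the new `get_container_format()` default implementation, which opens the stream with PyAV and reads the container's format name. The file name below is hypothetical; note that PyAV reports the FFmpeg demuxer name, which for some containers is a comma-separated alias list rather than a single extension.

```python
import av

# Hypothetical local file; any source PyAV can open works the same way.
with av.open("input.mp4", mode="r") as container:
    # For MP4 input this typically prints "mov,mp4,m4a,3gp,3g2,mj2".
    print(container.format.name)
```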