mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 23:49:57 +08:00
Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
55add50220 | ||
|
0aa2368e46 | ||
|
cca96a85ae |
@@ -293,7 +293,7 @@ class GeneralDIT(nn.Module):
|
|||||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||||
|
|
||||||
if self.extra_per_block_abs_pos_emb:
|
if self.extra_per_block_abs_pos_emb:
|
||||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
|
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||||
else:
|
else:
|
||||||
extra_pos_emb = None
|
extra_pos_emb = None
|
||||||
|
|
||||||
|
@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
|
|||||||
|
|
||||||
|
|
||||||
class VideoPositionEmb(nn.Module):
|
class VideoPositionEmb(nn.Module):
|
||||||
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
It delegates the embedding generation to generate_embeddings function.
|
It delegates the embedding generation to generate_embeddings function.
|
||||||
"""
|
"""
|
||||||
B_T_H_W_C = x_B_T_H_W_C.shape
|
B_T_H_W_C = x_B_T_H_W_C.shape
|
||||||
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
|
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@@ -104,6 +104,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
|||||||
w_ntk_factor: Optional[float] = None,
|
w_ntk_factor: Optional[float] = None,
|
||||||
t_ntk_factor: Optional[float] = None,
|
t_ntk_factor: Optional[float] = None,
|
||||||
device=None,
|
device=None,
|
||||||
|
dtype=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate embeddings for the given input size.
|
Generate embeddings for the given input size.
|
||||||
@@ -189,13 +190,12 @@ class LearnablePosEmbAxis(VideoPositionEmb):
|
|||||||
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
|
||||||
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
|
||||||
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
|
|
||||||
B, T, H, W, _ = B_T_H_W_C
|
B, T, H, W, _ = B_T_H_W_C
|
||||||
if self.interpolation == "crop":
|
if self.interpolation == "crop":
|
||||||
emb_h_H = self.pos_emb_h[:H].to(device=device)
|
emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
|
||||||
emb_w_W = self.pos_emb_w[:W].to(device=device)
|
emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
|
||||||
emb_t_T = self.pos_emb_t[:T].to(device=device)
|
emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
|
||||||
emb = (
|
emb = (
|
||||||
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
|
||||||
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
|
||||||
|
@@ -18,6 +18,7 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import math
|
||||||
|
|
||||||
from .cosmos_tokenizer.layers3d import (
|
from .cosmos_tokenizer.layers3d import (
|
||||||
EncoderFactorized,
|
EncoderFactorized,
|
||||||
@@ -105,17 +106,23 @@ class CausalContinuousVideoTokenizer(nn.Module):
|
|||||||
z, posteriors = self.distribution(moments)
|
z, posteriors = self.distribution(moments)
|
||||||
latent_ch = z.shape[1]
|
latent_ch = z.shape[1]
|
||||||
latent_t = z.shape[2]
|
latent_t = z.shape[2]
|
||||||
dtype = z.dtype
|
in_dtype = z.dtype
|
||||||
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
|
mean = self.latent_mean.view(latent_ch, -1)
|
||||||
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
|
std = self.latent_std.view(latent_ch, -1)
|
||||||
|
|
||||||
|
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
return ((z - mean) / std) * self.sigma_data
|
return ((z - mean) / std) * self.sigma_data
|
||||||
|
|
||||||
def decode(self, z):
|
def decode(self, z):
|
||||||
in_dtype = z.dtype
|
in_dtype = z.dtype
|
||||||
latent_ch = z.shape[1]
|
latent_ch = z.shape[1]
|
||||||
latent_t = z.shape[2]
|
latent_t = z.shape[2]
|
||||||
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
mean = self.latent_mean.view(latent_ch, -1)
|
||||||
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
std = self.latent_std.view(latent_ch, -1)
|
||||||
|
|
||||||
|
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
|
||||||
|
|
||||||
z = z / self.sigma_data
|
z = z / self.sigma_data
|
||||||
z = z * std + mean
|
z = z * std + mean
|
||||||
|
@@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.11"
|
__version__ = "0.3.12"
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.11"
|
version = "0.3.12"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
Reference in New Issue
Block a user