1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Use common function for casting weights to input.

This commit is contained in:
comfyanonymous
2024-07-30 05:03:20 -04:00
parent 79040635da
commit 25853d0be8
7 changed files with 51 additions and 31 deletions

View File

@@ -8,6 +8,7 @@ import torch.nn as nn
from .. import attention
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops
def default(x, y):
if x is not None:
@@ -926,7 +927,7 @@ class MMDiT(nn.Module):
context = self.context_processor(context)
hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D)