1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Basic Genmo Mochi video model support.

To use:
"Load CLIP" node with t5xxl + type mochi
"Load Diffusion Model" node with the mochi dit file.
"Load VAE" with the mochi vae file.

EmptyMochiLatentVideo node for the latent.
euler + linear_quadratic in the KSampler node.
This commit is contained in:
comfyanonymous
2024-10-26 06:54:00 -04:00
parent c3ffbae067
commit 5cbb01bc2f
18 changed files with 1677 additions and 24 deletions

View File

@@ -731,7 +731,27 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
dims = len(tile)
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return up * val
def mult_list_upscale(a):
out = []
for i in range(len(a)):
out.append(round(get_upscale(i, a[i])))
return out
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
@@ -743,27 +763,27 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount))
upscaled.append(round(get_upscale(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
for d in range(2, dims + 2):
for d in range(2, dims + 2):
feather = round(get_upscale(d - 2, overlap[d - 2]))
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)