mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Basic Genmo Mochi video model support.
To use: "Load CLIP" node with t5xxl + type mochi "Load Diffusion Model" node with the mochi dit file. "Load VAE" with the mochi vae file. EmptyMochiLatentVideo node for the latent. euler + linear_quadratic in the KSampler node.
This commit is contained in:
@@ -731,7 +731,27 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
dims = len(tile)
|
||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||
|
||||
if not (isinstance(upscale_amount, (tuple, list))):
|
||||
upscale_amount = [upscale_amount] * dims
|
||||
|
||||
if not (isinstance(overlap, (tuple, list))):
|
||||
overlap = [overlap] * dims
|
||||
|
||||
def get_upscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
if callable(up):
|
||||
return up(val)
|
||||
else:
|
||||
return up * val
|
||||
|
||||
def mult_list_upscale(a):
|
||||
out = []
|
||||
for i in range(len(a)):
|
||||
out.append(round(get_upscale(i, a[i])))
|
||||
return out
|
||||
|
||||
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b+1]
|
||||
@@ -743,27 +763,27 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2], tile[d] - overlap) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
for it in itertools.product(*positions):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
||||
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(pos * upscale_amount))
|
||||
upscaled.append(round(get_upscale(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
|
||||
for t in range(feather):
|
||||
for d in range(2, dims + 2):
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_upscale(d - 2, overlap[d - 2]))
|
||||
for t in range(feather):
|
||||
a = (t + 1) / feather
|
||||
mask.narrow(d, t, 1).mul_(a)
|
||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||
|
Reference in New Issue
Block a user