mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Properly handle and reshape masks when used on 3d latents.
This commit is contained in:
@@ -848,3 +848,24 @@ class ProgressBar:
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
def reshape_mask(input_mask, output_shape):
|
||||
dims = len(output_shape) - 2
|
||||
|
||||
if dims == 1:
|
||||
scale_mode = "linear"
|
||||
|
||||
if dims == 2:
|
||||
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||
scale_mode = "bilinear"
|
||||
|
||||
if dims == 3:
|
||||
if len(input_mask.shape) < 5:
|
||||
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||
scale_mode = "trilinear"
|
||||
|
||||
mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
|
||||
if mask.shape[1] < output_shape[1]:
|
||||
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
|
||||
return mask
|
||||
|
Reference in New Issue
Block a user