mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
[feat] Add ImageStitch node for concatenating images (#8369)
* [feat] Add ImageStitch node for concatenating images with borders Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing. Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage. * [fix] Fix CI issues with CUDA dependencies and linting - Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners - Fix ruff linting issues for code style compliance * [fix] Improve CI compatibility by mocking nodes module import Prevent CUDA initialization chain by mocking the nodes module at import time, which is cleaner than deep mocking of CUDA-specific functions. * [refactor] Clean up ImageStitch tests - Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini) - Remove metadata tests that test framework internals rather than functionality - Rename complex scenario test to be more descriptive of what it tests * [refactor] Rename 'border' to 'spacing' for semantic accuracy - Change border_width/border_color to spacing_width/spacing_color in API - Update all tests to use spacing terminology - Update comments and variable names throughout - More accurately describes the gap/separator between images
This commit is contained in:
@@ -14,6 +14,7 @@ import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
@@ -229,6 +230,186 @@ class SVG:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
||||
|
||||
class ImageStitch:
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||
"spacing_width": (
|
||||
"INT",
|
||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||
),
|
||||
"spacing_color": (
|
||||
["white", "black", "red", "green", "blue"],
|
||||
{"default": "white"},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stitch"
|
||||
CATEGORY = "image/transform"
|
||||
DESCRIPTION = """
|
||||
Stitches image2 to image1 in the specified direction.
|
||||
If image2 is not provided, returns image1 unchanged.
|
||||
Optional spacing can be added between images.
|
||||
"""
|
||||
|
||||
def stitch(
|
||||
self,
|
||||
image1,
|
||||
direction,
|
||||
match_image_size,
|
||||
spacing_width,
|
||||
spacing_color,
|
||||
image2=None,
|
||||
):
|
||||
if image2 is None:
|
||||
return (image1,)
|
||||
|
||||
# Handle batch size differences
|
||||
if image1.shape[0] != image2.shape[0]:
|
||||
max_batch = max(image1.shape[0], image2.shape[0])
|
||||
if image1.shape[0] < max_batch:
|
||||
image1 = torch.cat(
|
||||
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
|
||||
)
|
||||
if image2.shape[0] < max_batch:
|
||||
image2 = torch.cat(
|
||||
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
|
||||
)
|
||||
|
||||
# Match image sizes if requested
|
||||
if match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
aspect_ratio = w2 / h2
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
target_h, target_w = h1, int(h1 * aspect_ratio)
|
||||
else: # up, down
|
||||
target_w, target_h = w1, int(w1 / aspect_ratio)
|
||||
|
||||
image2 = comfy.utils.common_upscale(
|
||||
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
|
||||
).movedim(1, -1)
|
||||
|
||||
# When not matching sizes, pad to align non-concat dimensions
|
||||
if not match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
# For horizontal concat, pad heights to match
|
||||
if h1 != h2:
|
||||
target_h = max(h1, h2)
|
||||
if h1 < target_h:
|
||||
pad_h = target_h - h1
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||
if h2 < target_h:
|
||||
pad_h = target_h - h2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
|
||||
else: # up, down
|
||||
# For vertical concat, pad widths to match
|
||||
if w1 != w2:
|
||||
target_w = max(w1, w2)
|
||||
if w1 < target_w:
|
||||
pad_w = target_w - w1
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||
if w2 < target_w:
|
||||
pad_w = target_w - w2
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
|
||||
|
||||
# Ensure same number of channels
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
max_channels = max(image1.shape[-1], image2.shape[-1])
|
||||
if image1.shape[-1] < max_channels:
|
||||
image1 = torch.cat(
|
||||
[
|
||||
image1,
|
||||
torch.ones(
|
||||
*image1.shape[:-1],
|
||||
max_channels - image1.shape[-1],
|
||||
device=image1.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if image2.shape[-1] < max_channels:
|
||||
image2 = torch.cat(
|
||||
[
|
||||
image2,
|
||||
torch.ones(
|
||||
*image2.shape[:-1],
|
||||
max_channels - image2.shape[-1],
|
||||
device=image2.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Add spacing if specified
|
||||
if spacing_width > 0:
|
||||
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
|
||||
|
||||
color_map = {
|
||||
"white": 1.0,
|
||||
"black": 0.0,
|
||||
"red": (1.0, 0.0, 0.0),
|
||||
"green": (0.0, 1.0, 0.0),
|
||||
"blue": (0.0, 0.0, 1.0),
|
||||
}
|
||||
color_val = color_map[spacing_color]
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
max(image1.shape[1], image2.shape[1]),
|
||||
spacing_width,
|
||||
image1.shape[-1],
|
||||
)
|
||||
else:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
spacing_width,
|
||||
max(image1.shape[2], image2.shape[2]),
|
||||
image1.shape[-1],
|
||||
)
|
||||
|
||||
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
|
||||
if isinstance(color_val, tuple):
|
||||
for i, c in enumerate(color_val):
|
||||
if i < spacing.shape[-1]:
|
||||
spacing[..., i] = c
|
||||
if spacing.shape[-1] == 4: # Add alpha
|
||||
spacing[..., 3] = 1.0
|
||||
else:
|
||||
spacing[..., : min(3, spacing.shape[-1])] = color_val
|
||||
if spacing.shape[-1] == 4:
|
||||
spacing[..., 3] = 1.0
|
||||
|
||||
# Concatenate images
|
||||
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
|
||||
if spacing_width > 0:
|
||||
images.insert(1, spacing)
|
||||
|
||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||
return (torch.cat(images, dim=concat_dim),)
|
||||
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
@@ -318,4 +499,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
}
|
||||
|
Reference in New Issue
Block a user