mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Compare commits
29 Commits
bc6b0113e2
...
b6a4a4c664
Author | SHA1 | Date | |
---|---|---|---|
|
b6a4a4c664 | ||
|
fd9c34a3eb | ||
|
de0901bd02 | ||
|
2a7793394f | ||
|
18ed598fa1 | ||
|
7f492522b6 | ||
|
9eda706e64 | ||
|
650838fd6f | ||
|
491fafbd64 | ||
|
9bc2798f72 | ||
|
50afba747c | ||
|
6b8062f414 | ||
|
b1ae4126c3 | ||
|
9dabda19f0 | ||
|
543c24108c | ||
|
260a5ca5d9 | ||
|
861c3bbb3d | ||
|
9ca581c941 | ||
|
4831e9c2c4 | ||
|
480375f349 | ||
|
b40143984c | ||
|
b43916a134 | ||
|
7bc7dd2aa2 | ||
|
938d3e8216 | ||
|
8f05fb48ea | ||
|
b7ff5bd14d | ||
|
2b653e8c18 | ||
|
1fd306824d | ||
|
1205afc708 |
40
.github/workflows/check-line-endings.yml
vendored
Normal file
40
.github/workflows/check-line-endings.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Check for Windows Line Endings
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*'] # Trigger on all pull requests to any branch
|
||||
|
||||
jobs:
|
||||
check-line-endings:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history to compare changes
|
||||
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
# Loop through each changed file
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# Check if the file exists and is a text file
|
||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
||||
# Check for CRLF line endings
|
||||
if grep -UP '\r$' "$FILE"; then
|
||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
||||
CRLF_FOUND=true
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Exit with error if CRLF was found
|
||||
if [ "$CRLF_FOUND" = true ]; then
|
||||
exit 1
|
||||
fi
|
@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
|
@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
@@ -1,256 +1,256 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
11
comfy/sd.py
11
comfy/sd.py
@@ -18,6 +18,7 @@ import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
|
||||
import comfy.utils
|
||||
|
||||
@@ -977,6 +978,12 @@ def load_gligen(ckpt_path):
|
||||
model = model.half()
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def model_detection_error_hint(path, state_dict):
|
||||
filename = os.path.basename(path)
|
||||
if 'lora' in filename.lower():
|
||||
return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
|
||||
return ""
|
||||
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
|
||||
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
|
||||
@@ -1005,7 +1012,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
@@ -1177,7 +1184,7 @@ def load_diffusion_model(unet_path, model_options={}):
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
return model
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
|
@@ -18,7 +18,7 @@
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"model_max_length": 8192,
|
||||
"name_or_path": "openai/clip-vit-large-patch14",
|
||||
"pad_token": "<|endoftext|>",
|
||||
"special_tokens_map_file": "./special_tokens_map.json",
|
||||
|
@@ -1214,7 +1214,7 @@ class Omnigen2(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
|
@@ -1,42 +1,42 @@
|
||||
import os
|
||||
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy.sd1_clip import gen_empty_tokens
|
||||
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
|
||||
# PixArt expects the negative to be all pad tokens
|
||||
special_tokens = special_tokens.copy()
|
||||
special_tokens.pop("end")
|
||||
return gen_empty_tokens(special_tokens, *args, **kwargs)
|
||||
|
||||
class PixArtT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
|
||||
|
||||
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
import os
|
||||
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy.sd1_clip import gen_empty_tokens
|
||||
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
|
||||
# PixArt expects the negative to be all pad tokens
|
||||
special_tokens = special_tokens.copy()
|
||||
special_tokens.pop("end")
|
||||
return gen_empty_tokens(special_tokens, *args, **kwargs)
|
||||
|
||||
class PixArtT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
|
||||
|
||||
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
|
@@ -31,6 +31,7 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
@@ -58,7 +59,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
sd[k] = f.get_tensor(k)
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
@@ -998,11 +1002,12 @@ def set_progress_bar_global_hook(function):
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total):
|
||||
def __init__(self, total, node_id=None):
|
||||
global PROGRESS_BAR_HOOK
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.node_id = node_id
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
@@ -1011,7 +1016,7 @@ class ProgressBar:
|
||||
value = self.total
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total, preview)
|
||||
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
69
comfy_api/feature_flags.py
Normal file
69
comfy_api/feature_flags.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Feature flags module for ComfyUI WebSocket protocol negotiation.
|
||||
|
||||
This module handles capability negotiation between frontend and backend,
|
||||
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# Default server capabilities
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
}
|
||||
|
||||
|
||||
def get_connection_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str,
|
||||
default: Any = False
|
||||
) -> Any:
|
||||
"""
|
||||
Get a feature flag value for a specific connection.
|
||||
|
||||
Args:
|
||||
sockets_metadata: Dictionary of socket metadata
|
||||
sid: Session ID of the connection
|
||||
feature_name: Name of the feature to check
|
||||
default: Default value if feature not found
|
||||
|
||||
Returns:
|
||||
Feature value or default if not found
|
||||
"""
|
||||
if sid not in sockets_metadata:
|
||||
return default
|
||||
|
||||
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
|
||||
|
||||
|
||||
def supports_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a connection supports a specific feature.
|
||||
|
||||
Args:
|
||||
sockets_metadata: Dictionary of socket metadata
|
||||
sid: Session ID of the connection
|
||||
feature_name: Name of the feature to check
|
||||
|
||||
Returns:
|
||||
Boolean indicating if feature is supported
|
||||
"""
|
||||
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
|
||||
|
||||
|
||||
def get_server_features() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the server's feature flags.
|
||||
|
||||
Returns:
|
||||
Dictionary of server feature flags
|
||||
"""
|
||||
return SERVER_FEATURE_FLAGS.copy()
|
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
@@ -112,3 +113,24 @@ def lock_class(cls):
|
||||
locked_dict['__setattr__'] = locked_instance_setattr
|
||||
|
||||
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||
|
||||
|
||||
def make_locked_method_func(type_obj, func, class_clone):
|
||||
"""
|
||||
Returns a function that, when called with **inputs, will execute:
|
||||
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||
|
||||
Supports both synchronous and asynchronous methods.
|
||||
"""
|
||||
locked_class = lock_class(class_clone)
|
||||
method = getattr(type_obj, func).__func__
|
||||
|
||||
# Check if the original method is async
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
async def wrapped_async_func(**inputs):
|
||||
return await method(locked_class, **inputs)
|
||||
return wrapped_async_func
|
||||
else:
|
||||
def wrapped_func(**inputs):
|
||||
return method(locked_class, **inputs)
|
||||
return wrapped_func
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter
|
||||
from dataclasses import asdict, dataclass
|
||||
@@ -75,6 +76,16 @@ class NumberDisplay(str, Enum):
|
||||
color = "color"
|
||||
|
||||
|
||||
class _StringIOType(str):
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
class ComfyType(ABC):
|
||||
Type = Any
|
||||
io_type: str = None
|
||||
@@ -114,8 +125,8 @@ def comfytype(io_type: str, **kwargs):
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
# assign ComfyType attributes, if needed
|
||||
# NOTE: do we need __ne__ trick for io_type? (see node_typing.IO.__ne__ for details)
|
||||
new_cls.io_type = io_type
|
||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
||||
new_cls.io_type = _StringIOType(io_type)
|
||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||
new_cls.Input.Parent = new_cls
|
||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||
@@ -169,7 +180,7 @@ class InputV3(IO_V3):
|
||||
}) | prune_dict(self.extra_dict)
|
||||
|
||||
def get_io_type(self):
|
||||
return self.io_type
|
||||
return _StringIOType(self.io_type)
|
||||
|
||||
class WidgetInputV3(InputV3):
|
||||
'''
|
||||
@@ -439,6 +450,12 @@ class MultiCombo(ComfyTypeI):
|
||||
class Image(ComfyTypeIO):
|
||||
Type = torch.Tensor
|
||||
|
||||
|
||||
@comfytype(io_type="WAN_CAMERA_EMBEDDING")
|
||||
class WanCameraEmbedding(ComfyTypeIO):
|
||||
Type = torch.Tensor
|
||||
|
||||
|
||||
@comfytype(io_type="WEBCAM")
|
||||
class Webcam(ComfyTypeIO):
|
||||
Type = str
|
||||
@@ -1221,6 +1238,12 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
||||
if first_real_override(cls, "execute") is None:
|
||||
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
|
||||
|
||||
@classproperty
|
||||
def FUNCTION(cls): # noqa
|
||||
if inspect.iscoroutinefunction(cls.execute):
|
||||
return "EXECUTE_NORMALIZED_ASYNC"
|
||||
return "EXECUTE_NORMALIZED"
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput:
|
||||
@@ -1238,6 +1261,23 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
||||
else:
|
||||
raise Exception(f"Invalid return type from node: {type(to_return)}")
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
async def EXECUTE_NORMALIZED_ASYNC(cls, *args, **kwargs) -> NodeOutput:
|
||||
to_return = await cls.execute(*args, **kwargs)
|
||||
if to_return is None:
|
||||
return NodeOutput()
|
||||
elif isinstance(to_return, NodeOutput):
|
||||
return to_return
|
||||
elif isinstance(to_return, tuple):
|
||||
return NodeOutput(*to_return)
|
||||
elif isinstance(to_return, dict):
|
||||
return NodeOutput.from_dict(to_return)
|
||||
elif isinstance(to_return, ExecutionBlocker):
|
||||
return NodeOutput(block_execution=to_return.message)
|
||||
else:
|
||||
raise Exception(f"Invalid return type from node: {type(to_return)}")
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNodeV3]:
|
||||
@@ -1360,8 +1400,6 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._NOT_IDEMPOTENT
|
||||
|
||||
FUNCTION = "EXECUTE_NORMALIZED"
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]:
|
||||
|
@@ -406,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
|
||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||
mime_type = (
|
||||
GeminiMimeType.pdf
|
||||
GeminiMimeType.application_pdf
|
||||
if file_path.endswith(".pdf")
|
||||
else GeminiMimeType.text_plain
|
||||
)
|
||||
|
@@ -132,6 +132,8 @@ def poll_until_finished(
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
|
||||
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet:
|
||||
class CacheKeySet(ABC):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
@abstractmethod
|
||||
async def add_keys(self, node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
@@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
async def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
@@ -150,9 +150,10 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
await self.cache_key_set.add_keys(node_ids)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.initialized = True
|
||||
|
||||
@@ -201,13 +202,13 @@ class BasicCache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
@@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
@@ -273,8 +274,8 @@ class LRUCache(BasicCache):
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
@@ -303,11 +304,11 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
super()._ensure_subcache(node_id, children_ids)
|
||||
await super()._ensure_subcache(node_id, children_ids)
|
||||
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
await self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
@@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
@@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
@@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
@@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
|
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
from typing import Type, Literal
|
||||
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
@@ -100,6 +102,8 @@ class TopologicalSort:
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.externalBlocks = 0
|
||||
self.unblockedEvent = asyncio.Event()
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
@@ -153,6 +157,16 @@ class TopologicalSort:
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
def add_external_block(self, node_id):
|
||||
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
|
||||
self.externalBlocks += 1
|
||||
self.blockCount[node_id] += 1
|
||||
def unblock():
|
||||
self.externalBlocks -= 1
|
||||
self.blockCount[node_id] -= 1
|
||||
self.unblockedEvent.set()
|
||||
return unblock
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return False
|
||||
|
||||
@@ -181,11 +195,16 @@ class ExecutionList(TopologicalSort):
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def stage_node_execution(self):
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None, None, None
|
||||
available = self.get_ready_nodes()
|
||||
while len(available) == 0 and self.externalBlocks > 0:
|
||||
# Wait for an external block to be released
|
||||
await self.unblockedEvent.wait()
|
||||
self.unblockedEvent.clear()
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
cycled_nodes = self.get_nodes_in_cycle()
|
||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||
@@ -221,8 +240,15 @@ class ExecutionList(TopologicalSort):
|
||||
return True
|
||||
return False
|
||||
|
||||
# If an available node is async, do that first.
|
||||
# This will execute the asynchronous function earlier, reducing the overall time.
|
||||
def is_async(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id):
|
||||
if is_output(node_id) or is_async(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
|
347
comfy_execution/progress.py
Normal file
347
comfy_execution/progress.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
from abc import ABC
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from protocol import BinaryEventTypes
|
||||
from comfy_api import feature_flags
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
Running = "running"
|
||||
Finished = "finished"
|
||||
Error = "error"
|
||||
|
||||
|
||||
class NodeProgressState(TypedDict):
|
||||
"""
|
||||
A class to represent the state of a node's progress.
|
||||
"""
|
||||
|
||||
state: NodeState
|
||||
value: float
|
||||
max: float
|
||||
|
||||
|
||||
class ProgressHandler(ABC):
|
||||
"""
|
||||
Abstract base class for progress handlers.
|
||||
Progress handlers receive progress updates and display them in various ways.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
pass
|
||||
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node starts processing"""
|
||||
pass
|
||||
|
||||
def update_handler(
|
||||
self,
|
||||
node_id: str,
|
||||
value: float,
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node finishes processing"""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
"""Called when the progress registry is reset"""
|
||||
pass
|
||||
|
||||
def enable(self):
|
||||
"""Enable this handler"""
|
||||
self.enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable this handler"""
|
||||
self.enabled = False
|
||||
|
||||
|
||||
class CLIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that displays progress using tqdm progress bars in the CLI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("cli")
|
||||
self.progress_bars: Dict[str, tqdm] = {}
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Create a new tqdm progress bar
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=state["max"],
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars),
|
||||
)
|
||||
|
||||
@override
|
||||
def update_handler(
|
||||
self,
|
||||
node_id: str,
|
||||
value: float,
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=max_value,
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars),
|
||||
)
|
||||
self.progress_bars[node_id].update(value)
|
||||
else:
|
||||
# Update existing progress bar
|
||||
if max_value != self.progress_bars[node_id].total:
|
||||
self.progress_bars[node_id].total = max_value
|
||||
# Calculate the update amount (difference from current position)
|
||||
current_position = self.progress_bars[node_id].n
|
||||
update_amount = value - current_position
|
||||
if update_amount > 0:
|
||||
self.progress_bars[node_id].update(update_amount)
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Complete and close the progress bar if it exists
|
||||
if node_id in self.progress_bars:
|
||||
# Ensure the bar shows 100% completion
|
||||
remaining = state["max"] - self.progress_bars[node_id].n
|
||||
if remaining > 0:
|
||||
self.progress_bars[node_id].update(remaining)
|
||||
self.progress_bars[node_id].close()
|
||||
del self.progress_bars[node_id]
|
||||
|
||||
@override
|
||||
def reset(self):
|
||||
# Close all progress bars
|
||||
for bar in self.progress_bars.values():
|
||||
bar.close()
|
||||
self.progress_bars.clear()
|
||||
|
||||
|
||||
class WebUIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that sends progress updates to the WebUI via WebSockets.
|
||||
"""
|
||||
|
||||
def __init__(self, server_instance):
|
||||
super().__init__("webui")
|
||||
self.server_instance = server_instance
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
self.registry = registry
|
||||
|
||||
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||
"""Send the current progress state to the client"""
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
"value": state["value"],
|
||||
"max": state["max"],
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
}
|
||||
for node_id, state in nodes.items()
|
||||
if state["state"] != NodeState.Pending
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
|
||||
)
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
@override
|
||||
def update_handler(
|
||||
self,
|
||||
node_id: str,
|
||||
value: float,
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
if image:
|
||||
# Only send new format if client supports it
|
||||
if feature_flags.supports_feature(
|
||||
self.server_instance.sockets_metadata,
|
||||
self.server_instance.client_id,
|
||||
"supports_preview_metadata",
|
||||
):
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||
node_id
|
||||
),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(
|
||||
node_id
|
||||
),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
}
|
||||
self.server_instance.send_sync(
|
||||
BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
|
||||
(image, metadata),
|
||||
self.server_instance.client_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
|
||||
class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
self.handlers: Dict[str, ProgressHandler] = {}
|
||||
|
||||
def register_handler(self, handler: ProgressHandler) -> None:
|
||||
"""Register a progress handler"""
|
||||
self.handlers[handler.name] = handler
|
||||
|
||||
def unregister_handler(self, handler_name: str) -> None:
|
||||
"""Unregister a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
# Allow handler to clean up resources
|
||||
self.handlers[handler_name].reset()
|
||||
del self.handlers[handler_name]
|
||||
|
||||
def enable_handler(self, handler_name: str) -> None:
|
||||
"""Enable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].enable()
|
||||
|
||||
def disable_handler(self, handler_name: str) -> None:
|
||||
"""Disable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].disable()
|
||||
|
||||
def ensure_entry(self, node_id: str) -> NodeProgressState:
|
||||
"""Ensure a node entry exists"""
|
||||
if node_id not in self.nodes:
|
||||
self.nodes[node_id] = NodeProgressState(
|
||||
state=NodeState.Pending, value=0, max=1
|
||||
)
|
||||
return self.nodes[node_id]
|
||||
|
||||
def start_progress(self, node_id: str) -> None:
|
||||
"""Start progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = 0.0
|
||||
entry["max"] = 1.0
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(
|
||||
self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
|
||||
) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = value
|
||||
entry["max"] = max_value
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.update_handler(
|
||||
node_id, value, max_value, entry, self.prompt_id, image
|
||||
)
|
||||
|
||||
def finish_progress(self, node_id: str) -> None:
|
||||
"""Finish progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Finished
|
||||
entry["value"] = entry["max"]
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.finish_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def reset_handlers(self) -> None:
|
||||
"""Reset all handlers"""
|
||||
for handler in self.handlers.values():
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = None
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
if global_progress_registry is not None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
registry = get_progress_state()
|
||||
handler.set_registry(registry)
|
||||
registry.register_handler(handler)
|
||||
|
||||
|
||||
def get_progress_state() -> ProgressRegistry:
|
||||
global global_progress_registry
|
||||
if global_progress_registry is None:
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
|
||||
global_progress_registry = ProgressRegistry(
|
||||
prompt_id="", dynprompt=DynamicPrompt({})
|
||||
)
|
||||
return global_progress_registry
|
46
comfy_execution/utils.py
Normal file
46
comfy_execution/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
)
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
self.token = current_executing_context.set(self.context)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.token is not None:
|
||||
current_executing_context.reset(self.token)
|
@@ -40,6 +40,33 @@ class CFGZeroStar:
|
||||
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
||||
return (m, )
|
||||
|
||||
class CFGNorm:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "advanced/guidance"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model, strength):
|
||||
m = model.clone()
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CFGZeroStar": CFGZeroStar
|
||||
"CFGZeroStar": CFGZeroStar,
|
||||
"CFGNorm": CFGNorm,
|
||||
}
|
||||
|
@@ -71,8 +71,11 @@ class FreSca:
|
||||
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
|
||||
def patch(self, model, scale_low, scale_high, freq_cutoff):
|
||||
def custom_cfg_function(args):
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
||||
return conds_out
|
||||
cond = conds_out[0]
|
||||
uncond = conds_out[1]
|
||||
|
||||
guidance = cond - uncond
|
||||
filtered_guidance = Fourier_filter(
|
||||
@@ -83,7 +86,7 @@ class FreSca:
|
||||
)
|
||||
filtered_cond = filtered_guidance + uncond
|
||||
|
||||
return [filtered_cond, uncond]
|
||||
return [filtered_cond, uncond] + conds_out[2:]
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
||||
|
@@ -247,7 +247,7 @@ class MaskComposite:
|
||||
visible_width, visible_height = (right - left, bottom - top,)
|
||||
|
||||
source_portion = source[:, :visible_height, :visible_width]
|
||||
destination_portion = destination[:, top:bottom, left:right]
|
||||
destination_portion = output[:, top:bottom, left:right]
|
||||
|
||||
if operation == "multiply":
|
||||
output[:, top:bottom, left:right] = destination_portion * source_portion
|
||||
|
@@ -1,24 +1,24 @@
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class CLIPTextEncodePixArtAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "advanced/conditioning"
|
||||
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
|
||||
|
||||
def encode(self, clip, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
|
||||
}
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class CLIPTextEncodePixArtAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "advanced/conditioning"
|
||||
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
|
||||
|
||||
def encode(self, clip, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
|
||||
}
|
||||
|
@@ -23,38 +23,78 @@ from comfy.comfy_types.node_typing import IO
|
||||
from comfy.weight_adapter import adapters
|
||||
|
||||
|
||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
new_dict = {}
|
||||
for k, v in d.items():
|
||||
newv = v
|
||||
if isinstance(v, dict):
|
||||
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
if full_size is None or v.size(0) == full_size:
|
||||
newv = v[indicies]
|
||||
elif isinstance(v, (list, tuple)) and len(v) == full_size:
|
||||
newv = [v[i] for i in indicies]
|
||||
new_dict[k] = newv
|
||||
return new_dict
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
self.batch_size = batch_size
|
||||
self.total_steps = total_steps
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
self.optimizer.zero_grad()
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
|
||||
latent = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(sigmas),
|
||||
torch.zeros_like(noise, requires_grad=True),
|
||||
latent_image,
|
||||
False
|
||||
)
|
||||
cond = model_wrap.conds["positive"]
|
||||
dataset_size = sigmas.size(0)
|
||||
torch.cuda.empty_cache()
|
||||
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
|
||||
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
|
||||
|
||||
# Ensure model is in training mode and computing gradients
|
||||
# x0 pred
|
||||
denoised = model_wrap(noise, sigmas, **extra_args)
|
||||
try:
|
||||
loss = self.loss_fn(denoised, latent.clone())
|
||||
except RuntimeError as e:
|
||||
if "does not require grad and does not have a grad_fn" in str(e):
|
||||
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
) for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
|
||||
self.optimizer.step()
|
||||
# torch.cuda.memory._dump_snapshot("trainn.pickle")
|
||||
# torch.cuda.memory._record_memory_history(enabled=None)
|
||||
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(batch_sigmas),
|
||||
torch.zeros_like(batch_noise),
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
|
||||
model_wrap.conds["positive"] = [
|
||||
cond[i] for i in indicies
|
||||
]
|
||||
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
|
||||
|
||||
with torch.autocast(xt.device.type, dtype=self.training_dtype):
|
||||
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
|
||||
@@ -584,36 +624,34 @@ class TrainLoraNode:
|
||||
loss_map = {"loss": []}
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
pbar.set_postfix({"loss": f"{loss:.4f}"})
|
||||
train_sampler = TrainSampler(
|
||||
criterion, optimizer, loss_callback=loss_callback
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
total_steps=steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
|
||||
# yoland: this currently resize to the first image in the dataset
|
||||
|
||||
# Training loop
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
# Generate random sigma
|
||||
sigmas = [mp.model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
) for _ in range(min(batch_size, num_images))]
|
||||
sigmas = torch.tensor(sigmas)
|
||||
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||
|
||||
indices = torch.randperm(num_images)[:batch_size]
|
||||
batch_latent = latents[indices].clone()
|
||||
guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
|
||||
guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
|
||||
# Generate dummy sigmas and noise
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
@@ -1,5 +1,8 @@
|
||||
import torch
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, IO
|
||||
import asyncio
|
||||
from comfy.utils import ProgressBar
|
||||
import time
|
||||
|
||||
|
||||
class TestNode(ComfyNodeABC):
|
||||
@@ -34,10 +37,41 @@ class TestNode(ComfyNodeABC):
|
||||
return (some_int, image)
|
||||
|
||||
|
||||
class TestSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sleep"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
async def sleep(self, value, seconds, unique_id):
|
||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||
start = time.time()
|
||||
expiration = start + seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
pbar.update_absolute(now - start)
|
||||
await asyncio.sleep(0.02)
|
||||
return (value,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"V1TestNode1": TestNode,
|
||||
"V1TestSleep": TestSleep,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"V1TestNode1": "V1 Test Node",
|
||||
"V1TestSleep": "V1 Test Sleep",
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@ import logging # noqa
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import asyncio
|
||||
|
||||
|
||||
@io.comfytype(io_type="XYZ")
|
||||
@@ -203,8 +204,43 @@ class NInputsTest(io.ComfyNodeV3):
|
||||
return io.NodeOutput(combined_image)
|
||||
|
||||
|
||||
class V3TestSleep(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="V3_TestSleep",
|
||||
display_name="V3 Test Sleep",
|
||||
category="_for_testing",
|
||||
description="Test async sleep functionality.",
|
||||
inputs=[
|
||||
io.AnyType.Input("value", display_name="Value"),
|
||||
io.Float.Input("seconds", display_name="Seconds", default=1.0, min=0.0, max=9999.0, step=0.01, tooltip="The amount of seconds to sleep."),
|
||||
],
|
||||
outputs=[
|
||||
io.AnyType.Output(),
|
||||
],
|
||||
hidden=[
|
||||
io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, value: io.AnyType.Type, seconds: io.Float.Type, **kwargs):
|
||||
logging.info(f"V3TestSleep: {cls.hidden.unique_id}")
|
||||
pbar = comfy.utils.ProgressBar(seconds, node_id=cls.hidden.unique_id)
|
||||
start = time.time()
|
||||
expiration = start + seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
pbar.update_absolute(now - start)
|
||||
await asyncio.sleep(0.02)
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
V3TestNode,
|
||||
V3LoraLoader,
|
||||
NInputsTest,
|
||||
V3TestSleep,
|
||||
]
|
||||
|
217
comfy_extras/v3/nodes_camera_trajectory.py
Normal file
217
comfy_extras/v3/nodes_camera_trajectory.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
import nodes
|
||||
from comfy_api.v3 import io
|
||||
|
||||
CAMERA_DICT = {
|
||||
"base_T_norm": 1.5,
|
||||
"base_angle": np.pi / 3,
|
||||
"Static": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 0.0]},
|
||||
"Pan Up": {"angle": [0.0, 0.0, 0.0], "T": [0.0, -1.0, 0.0]},
|
||||
"Pan Down": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 1.0, 0.0]},
|
||||
"Pan Left": {"angle": [0.0, 0.0, 0.0], "T": [-1.0, 0.0, 0.0]},
|
||||
"Pan Right": {"angle": [0.0, 0.0, 0.0], "T": [1.0, 0.0, 0.0]},
|
||||
"Zoom In": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 2.0]},
|
||||
"Zoom Out": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, -2.0]},
|
||||
"Anti Clockwise (ACW)": {"angle": [0.0, 0.0, -1.0], "T": [0.0, 0.0, 0.0]},
|
||||
"ClockWise (CW)": {"angle": [0.0, 0.0, 1.0], "T": [0.0, 0.0, 0.0]},
|
||||
}
|
||||
|
||||
|
||||
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device="cpu"):
|
||||
def get_relative_pose(cam_params):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
||||
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
||||
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
||||
cam_to_origin = 0
|
||||
target_cam_c2w = np.array([[1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1]])
|
||||
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
||||
ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
||||
return np.array(ret_poses, dtype=np.float32)
|
||||
|
||||
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
||||
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
||||
|
||||
sample_wh_ratio = width / height
|
||||
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
||||
|
||||
if pose_wh_ratio > sample_wh_ratio:
|
||||
resized_ori_w = height * pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fx = resized_ori_w * cam_param.fx / width
|
||||
else:
|
||||
resized_ori_h = width / pose_wh_ratio
|
||||
for cam_param in cam_params:
|
||||
cam_param.fy = resized_ori_h * cam_param.fy / height
|
||||
|
||||
intrinsic = np.asarray(
|
||||
[[cam_param.fx * width, cam_param.fy * height, cam_param.cx * width, cam_param.cy * height] for cam_param in cam_params],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
||||
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
||||
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
||||
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
||||
plucker_embedding = plucker_embedding[None]
|
||||
return rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||
|
||||
|
||||
class Camera:
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
||||
|
||||
def __init__(self, entry):
|
||||
fx, fy, cx, cy = entry[1:5]
|
||||
self.fx = fx
|
||||
self.fy = fy
|
||||
self.cx = cx
|
||||
self.cy = cy
|
||||
c2w_mat = np.array(entry[7:]).reshape(4, 4)
|
||||
self.c2w_mat = c2w_mat
|
||||
self.w2c_mat = np.linalg.inv(c2w_mat)
|
||||
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
|
||||
# c2w: B, V, 4, 4
|
||||
# K: B, V, 4
|
||||
|
||||
B = K.shape[0]
|
||||
|
||||
j, i = torch.meshgrid(
|
||||
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
||||
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
||||
indexing="ij",
|
||||
)
|
||||
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
||||
|
||||
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
||||
|
||||
zs = torch.ones_like(i) # [B, HxW]
|
||||
xs = (i - cx) / fx * zs
|
||||
ys = (j - cy) / fy * zs
|
||||
zs = zs.expand_as(ys)
|
||||
|
||||
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
||||
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
||||
|
||||
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
||||
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||
# c2w @ dirctions
|
||||
rays_dxo = torch.cross(rays_o, rays_d)
|
||||
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||
return plucker
|
||||
|
||||
|
||||
def get_camera_motion(angle, T, speed, n=81):
|
||||
def compute_R_form_rad_angle(angles):
|
||||
theta_x, theta_y, theta_z = angles
|
||||
Rx = np.array([[1, 0, 0], [0, np.cos(theta_x), -np.sin(theta_x)], [0, np.sin(theta_x), np.cos(theta_x)]])
|
||||
|
||||
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], [0, 1, 0], [-np.sin(theta_y), 0, np.cos(theta_y)]])
|
||||
|
||||
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], [np.sin(theta_z), np.cos(theta_z), 0], [0, 0, 1]])
|
||||
|
||||
R = np.dot(Rz, np.dot(Ry, Rx))
|
||||
return R
|
||||
|
||||
RT = []
|
||||
for i in range(n):
|
||||
_angle = (i / n) * speed * (CAMERA_DICT["base_angle"]) * angle
|
||||
R = compute_R_form_rad_angle(_angle)
|
||||
_T = (i / n) * speed * (CAMERA_DICT["base_T_norm"]) * (T.reshape(3, 1))
|
||||
_RT = np.concatenate([R, _T], axis=1)
|
||||
RT.append(_RT)
|
||||
RT = np.stack(RT)
|
||||
return RT
|
||||
|
||||
|
||||
class WanCameraEmbedding(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="WanCameraEmbedding_V3",
|
||||
category="camera",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"camera_pose",
|
||||
options=[
|
||||
"Static",
|
||||
"Pan Up",
|
||||
"Pan Down",
|
||||
"Pan Left",
|
||||
"Pan Right",
|
||||
"Zoom In",
|
||||
"Zoom Out",
|
||||
"Anti Clockwise (ACW)",
|
||||
"ClockWise (CW)",
|
||||
],
|
||||
default="Static",
|
||||
),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
|
||||
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
|
||||
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
|
||||
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
|
||||
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.WanCameraEmbedding.Output(display_name="camera_embedding"),
|
||||
io.Int.Output(display_name="width"),
|
||||
io.Int.Output(display_name="height"),
|
||||
io.Int.Output(display_name="length"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
|
||||
"""
|
||||
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
|
||||
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
|
||||
"""
|
||||
motion_list = [camera_pose]
|
||||
speed = speed
|
||||
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
|
||||
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
|
||||
RT = get_camera_motion(angle, T, speed, length)
|
||||
|
||||
trajs = []
|
||||
for cp in RT.tolist():
|
||||
traj = [fx, fy, cx, cy, 0, 0]
|
||||
traj.extend(cp[0])
|
||||
traj.extend(cp[1])
|
||||
traj.extend(cp[2])
|
||||
traj.extend([0, 0, 0, 1])
|
||||
trajs.append(traj)
|
||||
|
||||
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
|
||||
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
|
||||
control_camera_video = process_pose_params(cam_params, width=width, height=height)
|
||||
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
|
||||
|
||||
control_camera_video = torch.concat(
|
||||
[torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), control_camera_video[:, :, 1:]], dim=2
|
||||
).transpose(1, 2)
|
||||
|
||||
# Reshape, transpose, and view into desired shape
|
||||
b, f, c, h, w = control_camera_video.shape
|
||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
|
||||
return io.NodeOutput(control_camera_video, width, height, length)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
WanCameraEmbedding,
|
||||
]
|
32
comfy_extras/v3/nodes_canny.py
Normal file
32
comfy_extras/v3/nodes_canny.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from kornia.filters import canny
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class Canny(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="Canny_V3",
|
||||
category="image/preprocessors",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
|
||||
io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
return io.NodeOutput(img_out)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
Canny,
|
||||
]
|
88
comfy_extras/v3/nodes_cfg.py
Normal file
88
comfy_extras/v3/nodes_cfg.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def optimized_scale(positive, negative):
|
||||
positive_flat = positive.reshape(positive.shape[0], -1)
|
||||
negative_flat = negative.reshape(negative.shape[0], -1)
|
||||
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
|
||||
# Squared norm of uncondition
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
|
||||
|
||||
|
||||
class CFGNorm(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="CFGNorm_V3",
|
||||
category="advanced/guidance",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output("patched_model", display_name="patched_model")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, strength) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class CFGZeroStar(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="CFGZeroStar_V3",
|
||||
category="advanced/guidance",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
],
|
||||
outputs=[io.Model.Output("patched_model", display_name="patched_model")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
|
||||
def cfg_zero_star(args):
|
||||
guidance_scale = args['cond_scale']
|
||||
x = args['input']
|
||||
cond_p = args['cond_denoised']
|
||||
uncond_p = args['uncond_denoised']
|
||||
out = args["denoised"]
|
||||
alpha = optimized_scale(x - cond_p, x - uncond_p)
|
||||
|
||||
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
CFGNorm,
|
||||
CFGZeroStar,
|
||||
]
|
79
comfy_extras/v3/nodes_clip_sdxl.py
Normal file
79
comfy_extras/v3/nodes_clip_sdxl.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import nodes
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class CLIPTextEncodeSDXL(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="CLIPTextEncodeSDXL_V3",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.String.Input("text_g", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("text_l", multiline=True, dynamic_prompts=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(text_g)
|
||||
tokens["l"] = clip.tokenize(text_l)["l"]
|
||||
if len(tokens["l"]) != len(tokens["g"]):
|
||||
empty = clip.tokenize("")
|
||||
while len(tokens["l"]) < len(tokens["g"]):
|
||||
tokens["l"] += empty["l"]
|
||||
while len(tokens["l"]) > len(tokens["g"]):
|
||||
tokens["g"] += empty["g"]
|
||||
conditioning = clip.encode_from_tokens_scheduled(
|
||||
tokens,
|
||||
add_dict={
|
||||
"width": width,
|
||||
"height": height,
|
||||
"crop_w": crop_w,
|
||||
"crop_h": crop_h,
|
||||
"target_width": target_width,
|
||||
"target_height": target_height,
|
||||
},
|
||||
)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class CLIPTextEncodeSDXLRefiner(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="CLIPTextEncodeSDXLRefiner_V3",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01),
|
||||
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.String.Input("text", multiline=True, dynamic_prompts=True),
|
||||
io.Clip.Input("clip"),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ascore, width, height, text, clip) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(text)
|
||||
conditioning = clip.encode_from_tokens_scheduled(
|
||||
tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}
|
||||
)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
CLIPTextEncodeSDXL,
|
||||
CLIPTextEncodeSDXLRefiner,
|
||||
]
|
226
comfy_extras/v3/nodes_compositing.py
Normal file
226
comfy_extras/v3/nodes_compositing.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def resize_mask(mask, shape):
|
||||
return torch.nn.functional.interpolate(
|
||||
mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear"
|
||||
).squeeze(1)
|
||||
|
||||
|
||||
class PorterDuffMode(Enum):
|
||||
ADD = 0
|
||||
CLEAR = 1
|
||||
DARKEN = 2
|
||||
DST = 3
|
||||
DST_ATOP = 4
|
||||
DST_IN = 5
|
||||
DST_OUT = 6
|
||||
DST_OVER = 7
|
||||
LIGHTEN = 8
|
||||
MULTIPLY = 9
|
||||
OVERLAY = 10
|
||||
SCREEN = 11
|
||||
SRC = 12
|
||||
SRC_ATOP = 13
|
||||
SRC_IN = 14
|
||||
SRC_OUT = 15
|
||||
SRC_OVER = 16
|
||||
XOR = 17
|
||||
|
||||
|
||||
def porter_duff_composite(
|
||||
src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode
|
||||
):
|
||||
# convert mask to alpha
|
||||
src_alpha = 1 - src_alpha
|
||||
dst_alpha = 1 - dst_alpha
|
||||
# premultiply alpha
|
||||
src_image = src_image * src_alpha
|
||||
dst_image = dst_image * dst_alpha
|
||||
|
||||
# composite ops below assume alpha-premultiplied images
|
||||
if mode == PorterDuffMode.ADD:
|
||||
out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
|
||||
out_image = torch.clamp(src_image + dst_image, 0, 1)
|
||||
elif mode == PorterDuffMode.CLEAR:
|
||||
out_alpha = torch.zeros_like(dst_alpha)
|
||||
out_image = torch.zeros_like(dst_image)
|
||||
elif mode == PorterDuffMode.DARKEN:
|
||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
|
||||
elif mode == PorterDuffMode.DST:
|
||||
out_alpha = dst_alpha
|
||||
out_image = dst_image
|
||||
elif mode == PorterDuffMode.DST_ATOP:
|
||||
out_alpha = src_alpha
|
||||
out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
|
||||
elif mode == PorterDuffMode.DST_IN:
|
||||
out_alpha = src_alpha * dst_alpha
|
||||
out_image = dst_image * src_alpha
|
||||
elif mode == PorterDuffMode.DST_OUT:
|
||||
out_alpha = (1 - src_alpha) * dst_alpha
|
||||
out_image = (1 - src_alpha) * dst_image
|
||||
elif mode == PorterDuffMode.DST_OVER:
|
||||
out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
|
||||
out_image = dst_image + (1 - dst_alpha) * src_image
|
||||
elif mode == PorterDuffMode.LIGHTEN:
|
||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
|
||||
elif mode == PorterDuffMode.MULTIPLY:
|
||||
out_alpha = src_alpha * dst_alpha
|
||||
out_image = src_image * dst_image
|
||||
elif mode == PorterDuffMode.OVERLAY:
|
||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||
out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
|
||||
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
|
||||
elif mode == PorterDuffMode.SCREEN:
|
||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||
out_image = src_image + dst_image - src_image * dst_image
|
||||
elif mode == PorterDuffMode.SRC:
|
||||
out_alpha = src_alpha
|
||||
out_image = src_image
|
||||
elif mode == PorterDuffMode.SRC_ATOP:
|
||||
out_alpha = dst_alpha
|
||||
out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
|
||||
elif mode == PorterDuffMode.SRC_IN:
|
||||
out_alpha = src_alpha * dst_alpha
|
||||
out_image = src_image * dst_alpha
|
||||
elif mode == PorterDuffMode.SRC_OUT:
|
||||
out_alpha = (1 - dst_alpha) * src_alpha
|
||||
out_image = (1 - dst_alpha) * src_image
|
||||
elif mode == PorterDuffMode.SRC_OVER:
|
||||
out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
|
||||
out_image = src_image + (1 - src_alpha) * dst_image
|
||||
elif mode == PorterDuffMode.XOR:
|
||||
out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
|
||||
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
|
||||
else:
|
||||
return None, None
|
||||
|
||||
# back to non-premultiplied alpha
|
||||
out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
|
||||
out_image = torch.clamp(out_image, 0, 1)
|
||||
# convert alpha to mask
|
||||
out_alpha = 1 - out_alpha
|
||||
return out_image, out_alpha
|
||||
|
||||
|
||||
class JoinImageWithAlpha(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="JoinImageWithAlpha_V3",
|
||||
display_name="Join Image with Alpha _V3",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Mask.Input("alpha"),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
return io.NodeOutput(torch.stack(out_images))
|
||||
|
||||
|
||||
class PorterDuffImageComposite(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PorterDuffImageComposite_V3",
|
||||
display_name="Porter-Duff Image Composite _V3",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
io.Image.Input("source"),
|
||||
io.Mask.Input("source_alpha"),
|
||||
io.Image.Input("destination"),
|
||||
io.Mask.Input("destination_alpha"),
|
||||
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
|
||||
],
|
||||
outputs=[io.Image.Output(), io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode
|
||||
) -> io.NodeOutput:
|
||||
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||
out_images = []
|
||||
out_alphas = []
|
||||
|
||||
for i in range(batch_size):
|
||||
src_image = source[i]
|
||||
dst_image = destination[i]
|
||||
|
||||
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
|
||||
|
||||
src_alpha = source_alpha[i].unsqueeze(2)
|
||||
dst_alpha = destination_alpha[i].unsqueeze(2)
|
||||
|
||||
if dst_alpha.shape[:2] != dst_image.shape[:2]:
|
||||
upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
|
||||
upscale_output = comfy.utils.common_upscale(
|
||||
upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center'
|
||||
)
|
||||
dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||
if src_image.shape != dst_image.shape:
|
||||
upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
|
||||
upscale_output = comfy.utils.common_upscale(
|
||||
upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center'
|
||||
)
|
||||
src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||
if src_alpha.shape != dst_alpha.shape:
|
||||
upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
|
||||
upscale_output = comfy.utils.common_upscale(
|
||||
upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center'
|
||||
)
|
||||
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||
|
||||
out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
|
||||
|
||||
out_images.append(out_image)
|
||||
out_alphas.append(out_alpha.squeeze(2))
|
||||
|
||||
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
|
||||
|
||||
|
||||
class SplitImageWithAlpha(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SplitImageWithAlpha_V3",
|
||||
display_name="Split Image with Alpha _V3",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
],
|
||||
outputs=[io.Image.Output(), io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
|
||||
out_images = [i[:, :, :3] for i in image]
|
||||
out_alphas = [i[:, :, 3] if i.shape[2] > 3 else torch.ones_like(i[:, :, 0]) for i in image]
|
||||
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
JoinImageWithAlpha,
|
||||
PorterDuffImageComposite,
|
||||
SplitImageWithAlpha,
|
||||
]
|
60
comfy_extras/v3/nodes_cond.py
Normal file
60
comfy_extras/v3/nodes_cond.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class CLIPTextEncodeControlnet(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="CLIPTextEncodeControlnet_V3",
|
||||
category="_for_testing/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.String.Input("text", multiline=True, dynamic_prompts=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, conditioning, text) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(text)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['cross_attn_controlnet'] = cond
|
||||
n[1]['pooled_output_controlnet'] = pooled
|
||||
c.append(n)
|
||||
return io.NodeOutput(c)
|
||||
|
||||
|
||||
class T5TokenizerOptions(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="T5TokenizerOptions_V3",
|
||||
category="_for_testing/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, min_padding, min_length) -> io.NodeOutput:
|
||||
clip = clip.clone()
|
||||
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
|
||||
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
|
||||
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
|
||||
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
CLIPTextEncodeControlnet,
|
||||
T5TokenizerOptions,
|
||||
]
|
146
comfy_extras/v3/nodes_cosmos.py
Normal file
146
comfy_extras/v3/nodes_cosmos.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import nodes
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
||||
pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
pixel_len = min(pixels.shape[0], length)
|
||||
padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
|
||||
padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
|
||||
padded_pixels[:pixel_len] = pixels[:pixel_len]
|
||||
latent_len = ((pixel_len - 1) // 8) + 1
|
||||
latent_temp = vae.encode(padded_pixels)
|
||||
return latent_temp[:, :, :latent_len]
|
||||
|
||||
|
||||
class CosmosImageToVideoLatent(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="CosmosImageToVideoLatent_V3",
|
||||
category="conditioning/inpaint",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.Image.Input("end_image", optional=True),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is None and end_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
mask = torch.ones(
|
||||
[latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
|
||||
if start_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
||||
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
||||
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
||||
|
||||
if end_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
||||
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
||||
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
class CosmosPredict2ImageToVideoLatent(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="CosmosPredict2ImageToVideoLatent_V3",
|
||||
category="conditioning/inpaint",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.Image.Input("end_image", optional=True),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is None and end_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
mask = torch.ones(
|
||||
[latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
|
||||
if start_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
||||
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
||||
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
||||
|
||||
if end_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
||||
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
||||
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
class EmptyCosmosLatentVideo(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="EmptyCosmosLatentVideo_V3",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
CosmosImageToVideoLatent,
|
||||
CosmosPredict2ImageToVideoLatent,
|
||||
EmptyCosmosLatentVideo,
|
||||
]
|
50
comfy_extras/v3/nodes_differential_diffusion.py
Normal file
50
comfy_extras/v3/nodes_differential_diffusion.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class DifferentialDiffusion(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="DifferentialDiffusion_V3",
|
||||
display_name="Differential Diffusion _V3",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model):
|
||||
model = model.clone()
|
||||
model.set_model_denoise_mask_function(cls.forward)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
@classmethod
|
||||
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
|
||||
model = extra_options["model"]
|
||||
step_sigmas = extra_options["sigmas"]
|
||||
sigma_to = model.inner_model.model_sampling.sigma_min
|
||||
if step_sigmas[-1] > sigma_to:
|
||||
sigma_to = step_sigmas[-1]
|
||||
sigma_from = step_sigmas[0]
|
||||
|
||||
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
|
||||
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
|
||||
current_ts = model.inner_model.model_sampling.timestep(sigma[0])
|
||||
|
||||
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
||||
|
||||
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
DifferentialDiffusion,
|
||||
]
|
125
comfy_extras/v3/nodes_flux.py
Normal file
125
comfy_extras/v3/nodes_flux.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
from comfy_api.v3 import io
|
||||
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
(752, 1392),
|
||||
(800, 1328),
|
||||
(832, 1248),
|
||||
(880, 1184),
|
||||
(944, 1104),
|
||||
(1024, 1024),
|
||||
(1104, 944),
|
||||
(1184, 880),
|
||||
(1248, 832),
|
||||
(1328, 800),
|
||||
(1392, 752),
|
||||
(1456, 720),
|
||||
(1504, 688),
|
||||
(1568, 672),
|
||||
]
|
||||
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="CLIPTextEncodeFlux_V3",
|
||||
category="advanced/conditioning/flux",
|
||||
inputs=[
|
||||
io.Clip.Input(id="clip"),
|
||||
io.String.Input(id="clip_l", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input(id="t5xxl", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, clip_l, t5xxl, guidance):
|
||||
tokens = clip.tokenize(clip_l)
|
||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
|
||||
|
||||
|
||||
class FluxDisableGuidance(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FluxDisableGuidance_V3",
|
||||
category="advanced/conditioning/flux",
|
||||
description="This node completely disables the guidance embed on Flux and Flux like models",
|
||||
inputs=[
|
||||
io.Conditioning.Input(id="conditioning"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, conditioning):
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
||||
return io.NodeOutput(c)
|
||||
|
||||
|
||||
class FluxGuidance(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FluxGuidance_V3",
|
||||
category="advanced/conditioning/flux",
|
||||
inputs=[
|
||||
io.Conditioning.Input(id="conditioning"),
|
||||
io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, conditioning, guidance):
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
||||
return io.NodeOutput(c)
|
||||
|
||||
|
||||
class FluxKontextImageScale(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FluxKontextImageScale_V3",
|
||||
category="advanced/conditioning/flux",
|
||||
description="This node resizes the image to one that is more optimal for flux kontext.",
|
||||
inputs=[
|
||||
io.Image.Input(id="image"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image):
|
||||
width = image.shape[2]
|
||||
height = image.shape[1]
|
||||
aspect_ratio = width / height
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||
return io.NodeOutput(image)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
CLIPTextEncodeFlux,
|
||||
FluxDisableGuidance,
|
||||
FluxGuidance,
|
||||
FluxKontextImageScale,
|
||||
]
|
131
comfy_extras/v3/nodes_freelunch.py
Normal file
131
comfy_extras/v3/nodes_freelunch.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def Fourier_filter(x, threshold, scale):
|
||||
# FFT
|
||||
x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
|
||||
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
|
||||
|
||||
B, C, H, W = x_freq.shape
|
||||
mask = torch.ones((B, C, H, W), device=x.device)
|
||||
|
||||
crow, ccol = H // 2, W //2
|
||||
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
|
||||
x_freq = x_freq * mask
|
||||
|
||||
# IFFT
|
||||
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
|
||||
x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
|
||||
|
||||
return x_filtered.to(x.dtype)
|
||||
|
||||
|
||||
class FreeU(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FreeU_V3",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.Float.Input(id="b1", default=1.1, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input(id="b2", default=1.2, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
on_cpu_devices = {}
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
scale = scale_dict.get(int(h.shape[1]), None)
|
||||
if scale is not None:
|
||||
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
||||
if hsp.device not in on_cpu_devices:
|
||||
try:
|
||||
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
||||
except Exception:
|
||||
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
|
||||
on_cpu_devices[hsp.device] = True
|
||||
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||
else:
|
||||
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class FreeU_V2(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FreeU_V2_V3",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.Float.Input(id="b1", default=1.3, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input(id="b2", default=1.4, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
on_cpu_devices = {}
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
scale = scale_dict.get(int(h.shape[1]), None)
|
||||
if scale is not None:
|
||||
hidden_mean = h.mean(1).unsqueeze(1)
|
||||
B = hidden_mean.shape[0]
|
||||
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
||||
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
||||
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
||||
|
||||
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
|
||||
|
||||
if hsp.device not in on_cpu_devices:
|
||||
try:
|
||||
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
||||
except Exception:
|
||||
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
|
||||
on_cpu_devices[hsp.device] = True
|
||||
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||
else:
|
||||
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
FreeU,
|
||||
FreeU_V2,
|
||||
]
|
110
comfy_extras/v3/nodes_fresca.py
Normal file
110
comfy_extras/v3/nodes_fresca.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Code based on https://github.com/WikiChao/FreSca (MIT License)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.fft as fft
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
||||
"""
|
||||
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
|
||||
|
||||
Parameters:
|
||||
x: Input tensor of shape (B, C, H, W)
|
||||
scale_low: Scaling factor for low-frequency components (default: 1.0)
|
||||
scale_high: Scaling factor for high-frequency components (default: 1.5)
|
||||
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
|
||||
|
||||
Returns:
|
||||
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
|
||||
"""
|
||||
# Preserve input dtype and device
|
||||
dtype, device = x.dtype, x.device
|
||||
|
||||
# Convert to float32 for FFT computations
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# 1) Apply FFT and shift low frequencies to center
|
||||
x_freq = fft.fftn(x, dim=(-2, -1))
|
||||
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
||||
|
||||
# Initialize mask with high-frequency scaling factor
|
||||
mask = torch.ones(x_freq.shape, device=device) * scale_high
|
||||
m = mask
|
||||
for d in range(len(x_freq.shape) - 2):
|
||||
dim = d + 2
|
||||
cc = x_freq.shape[dim] // 2
|
||||
f_c = min(freq_cutoff, cc)
|
||||
m = m.narrow(dim, cc - f_c, f_c * 2)
|
||||
|
||||
# Apply low-frequency scaling factor to center region
|
||||
m[:] = scale_low
|
||||
|
||||
# 3) Apply frequency-specific scaling
|
||||
x_freq = x_freq * mask
|
||||
|
||||
# 4) Convert back to spatial domain
|
||||
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
||||
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
||||
|
||||
# 5) Restore original dtype
|
||||
x_filtered = x_filtered.to(dtype)
|
||||
|
||||
return x_filtered
|
||||
|
||||
|
||||
class FreSca(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FreSca_V3",
|
||||
display_name="FreSca _V3",
|
||||
category="_for_testing",
|
||||
description="Applies frequency-dependent scaling to the guidance",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.Float.Input(id="scale_low", default=1.0, min=0, max=10, step=0.01,
|
||||
tooltip="Scaling factor for low-frequency components"),
|
||||
io.Float.Input(id="scale_high", default=1.25, min=0, max=10, step=0.01,
|
||||
tooltip="Scaling factor for high-frequency components"),
|
||||
io.Int.Input(id="freq_cutoff", default=20, min=1, max=10000, step=1,
|
||||
tooltip="Number of frequency indices around center to consider as low-frequency"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, scale_low, scale_high, freq_cutoff):
|
||||
def custom_cfg_function(args):
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
||||
return conds_out
|
||||
cond = conds_out[0]
|
||||
uncond = conds_out[1]
|
||||
|
||||
guidance = cond - uncond
|
||||
filtered_guidance = Fourier_filter(
|
||||
guidance,
|
||||
scale_low=scale_low,
|
||||
scale_high=scale_high,
|
||||
freq_cutoff=freq_cutoff,
|
||||
)
|
||||
filtered_cond = filtered_guidance + uncond
|
||||
|
||||
return [filtered_cond, uncond] + conds_out[2:]
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
FreSca,
|
||||
]
|
376
comfy_extras/v3/nodes_gits.py
Normal file
376
comfy_extras/v3/nodes_gits.py
Normal file
@@ -0,0 +1,376 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""Performs log-linear interpolation of a given array of decreasing numbers."""
|
||||
xs = np.linspace(0, 1, len(t_steps))
|
||||
ys = np.log(t_steps[::-1])
|
||||
|
||||
new_xs = np.linspace(0, 1, num_steps)
|
||||
new_ys = np.interp(new_xs, xs, ys)
|
||||
|
||||
return np.exp(new_ys)[::-1].copy()
|
||||
|
||||
|
||||
NOISE_LEVELS = {
|
||||
0.80: [
|
||||
[14.61464119, 7.49001646, 0.02916753],
|
||||
[14.61464119, 11.54541874, 6.77309084, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 3.07277966, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.75677586, 2.84484982, 1.78698075, 0.803307, 0.02916753],
|
||||
],
|
||||
0.85: [
|
||||
[14.61464119, 7.49001646, 0.02916753],
|
||||
[14.61464119, 7.49001646, 1.84880662, 0.02916753],
|
||||
[14.61464119, 11.54541874, 6.77309084, 1.56271636, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.11996698, 3.07277966, 1.24153244, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.60512662, 2.6383388, 1.56271636, 0.72133851, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
||||
],
|
||||
0.90: [
|
||||
[14.61464119, 6.77309084, 0.02916753],
|
||||
[14.61464119, 7.49001646, 1.56271636, 0.02916753],
|
||||
[14.61464119, 7.49001646, 3.07277966, 0.95350921, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.11996698, 4.86714602, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 2.95596409, 1.61558151, 0.69515091, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
|
||||
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.19988537, 1.51179266, 0.89115214, 0.43325692, 0.02916753],
|
||||
],
|
||||
0.95: [
|
||||
[14.61464119, 6.77309084, 0.02916753],
|
||||
[14.61464119, 6.77309084, 1.56271636, 0.02916753],
|
||||
[14.61464119, 7.49001646, 2.84484982, 0.89115214, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.36326075, 0.803307, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.41535246, 0.803307, 0.38853383, 0.02916753],
|
||||
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
||||
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
||||
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
|
||||
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.78698075, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
|
||||
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
|
||||
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.75677586, 3.07277966, 2.45070267, 1.91321158, 1.46270394, 1.05362725, 0.72133851, 0.43325692, 0.19894916, 0.02916753],
|
||||
],
|
||||
1.00: [
|
||||
[14.61464119, 1.56271636, 0.02916753],
|
||||
[14.61464119, 6.77309084, 0.95350921, 0.02916753],
|
||||
[14.61464119, 6.77309084, 2.36326075, 0.803307, 0.02916753],
|
||||
[14.61464119, 7.11996698, 3.07277966, 1.56271636, 0.59516323, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.41535246, 0.57119018, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.54755926, 0.25053367, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.12350607, 1.56271636, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.803307, 0.50118381, 0.27464288, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.84880662, 1.36964464, 1.01931262, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.46139455, 2.84484982, 2.19988537, 1.67050016, 1.24153244, 0.92192322, 0.64427125, 0.43325692, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.05: [
|
||||
[14.61464119, 0.95350921, 0.02916753],
|
||||
[14.61464119, 6.77309084, 0.89115214, 0.02916753],
|
||||
[14.61464119, 6.77309084, 2.05039096, 0.72133851, 0.02916753],
|
||||
[14.61464119, 6.77309084, 2.84484982, 1.28281462, 0.52423614, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.803307, 0.34370604, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.56271636, 0.803307, 0.34370604, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.52423614, 0.22545385, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.74807048, 0.41087446, 0.17026083, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.59516323, 0.34370604, 0.13792117, 0.02916753],
|
||||
[14.61464119, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.72759056, 1.24153244, 0.86115354, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.67050016, 1.28281462, 0.95350921, 0.72133851, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.36326075, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.61951244, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.98035145, 1.61558151, 1.32549286, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.10: [
|
||||
[14.61464119, 0.89115214, 0.02916753],
|
||||
[14.61464119, 2.36326075, 0.72133851, 0.02916753],
|
||||
[14.61464119, 5.85520077, 1.61558151, 0.57119018, 0.02916753],
|
||||
[14.61464119, 6.77309084, 2.45070267, 1.08895338, 0.45573691, 0.02916753],
|
||||
[14.61464119, 6.77309084, 2.95596409, 1.56271636, 0.803307, 0.34370604, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.89115214, 0.4783645, 0.19894916, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.08895338, 0.64427125, 0.34370604, 0.13792117, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.54755926, 0.27464288, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.4783645, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.41535246, 0.95350921, 0.64427125, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.61558151, 1.12534678, 0.803307, 0.54755926, 0.36617002, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.32507086, 2.45070267, 1.72759056, 1.24153244, 0.89115214, 0.64427125, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.05039096, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.12350607, 1.61558151, 1.24153244, 0.95350921, 0.72133851, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.15: [
|
||||
[14.61464119, 0.83188516, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.59516323, 0.02916753],
|
||||
[14.61464119, 5.85520077, 1.56271636, 0.52423614, 0.02916753],
|
||||
[14.61464119, 5.85520077, 1.91321158, 0.83188516, 0.34370604, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.51179266, 0.803307, 0.41087446, 0.17026083, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.56271636, 0.89115214, 0.50118381, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.12534678, 0.72133851, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.07277966, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.36964464, 0.95350921, 0.69515091, 0.4783645, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.20: [
|
||||
[14.61464119, 0.803307, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.52423614, 0.02916753],
|
||||
[14.61464119, 2.36326075, 0.92192322, 0.36617002, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.05039096, 0.95350921, 0.45573691, 0.17026083, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.64427125, 0.29807833, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.803307, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.95350921, 0.59516323, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.83188516, 0.59516323, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 3.07277966, 1.98035145, 1.36964464, 0.95350921, 0.69515091, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.46139455, 2.36326075, 1.56271636, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 6.77309084, 3.46139455, 2.45070267, 1.61558151, 1.162866, 0.86115354, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.20157266, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.25: [
|
||||
[14.61464119, 0.72133851, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.50118381, 0.02916753],
|
||||
[14.61464119, 2.05039096, 0.803307, 0.32104823, 0.02916753],
|
||||
[14.61464119, 2.36326075, 0.95350921, 0.43325692, 0.17026083, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.27464288, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.51179266, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.36326075, 1.24153244, 0.72133851, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.98595673, 0.64427125, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.52423614, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.86115354, 0.64427125, 0.4783645, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.28281462, 0.92192322, 0.69515091, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.46270394, 1.08895338, 0.83188516, 0.66947293, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.30: [
|
||||
[14.61464119, 0.72133851, 0.02916753],
|
||||
[14.61464119, 1.24153244, 0.43325692, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.59516323, 0.22545385, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.803307, 0.36617002, 0.13792117, 0.02916753],
|
||||
[14.61464119, 2.36326075, 1.01931262, 0.52423614, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.36964464, 0.74807048, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.56271636, 0.89115214, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.61558151, 0.95350921, 0.61951244, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.41535246, 0.92192322, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.6383388, 1.56271636, 1.01931262, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.77538133, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.35: [
|
||||
[14.61464119, 0.69515091, 0.02916753],
|
||||
[14.61464119, 0.95350921, 0.34370604, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.57119018, 0.19894916, 0.02916753],
|
||||
[14.61464119, 1.61558151, 0.69515091, 0.29807833, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.83188516, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.162866, 0.64427125, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.36964464, 0.803307, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.41535246, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.32104823, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.45070267, 1.51179266, 1.01931262, 0.74807048, 0.57119018, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.40: [
|
||||
[14.61464119, 0.59516323, 0.02916753],
|
||||
[14.61464119, 0.95350921, 0.34370604, 0.02916753],
|
||||
[14.61464119, 1.08895338, 0.43325692, 0.13792117, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.64427125, 0.27464288, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.61558151, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.05039096, 0.95350921, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.24153244, 0.72133851, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.52423614, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.41535246, 0.86115354, 0.59516323, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.4783645, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.72133851, 0.54755926, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.43325692, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.45: [
|
||||
[14.61464119, 0.59516323, 0.02916753],
|
||||
[14.61464119, 0.803307, 0.25053367, 0.02916753],
|
||||
[14.61464119, 0.95350921, 0.34370604, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.91321158, 0.95350921, 0.57119018, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.19988537, 1.08895338, 0.64427125, 0.41087446, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.72133851, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.59516323, 0.50118381, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
1.50: [
|
||||
[14.61464119, 0.54755926, 0.02916753],
|
||||
[14.61464119, 0.803307, 0.25053367, 0.02916753],
|
||||
[14.61464119, 0.86115354, 0.32104823, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.61558151, 0.83188516, 0.52423614, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.38853383, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.41087446, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 1.84880662, 0.95350921, 0.61951244, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
[14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class GITSScheduler(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="GITSScheduler_V3",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Float.Input(id="coeff", default=1.20, min=0.80, max=1.50, step=0.05),
|
||||
io.Int.Input(id="steps", default=10, min=2, max=1000),
|
||||
io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Sigmas.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, coeff, steps, denoise):
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
if denoise <= 0.0:
|
||||
return io.NodeOutput(torch.FloatTensor([]))
|
||||
total_steps = round(steps * denoise)
|
||||
|
||||
if steps <= 20:
|
||||
sigmas = NOISE_LEVELS[round(coeff, 2)][steps-2][:]
|
||||
else:
|
||||
sigmas = NOISE_LEVELS[round(coeff, 2)][-1][:]
|
||||
sigmas = loglinear_interp(sigmas, steps + 1)
|
||||
|
||||
sigmas = sigmas[-(total_steps + 1):]
|
||||
sigmas[-1] = 0
|
||||
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
GITSScheduler,
|
||||
]
|
148
comfy_extras/v3/nodes_rebatch.py
Normal file
148
comfy_extras/v3/nodes_rebatch.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class ImageRebatch(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="RebatchImages_V3",
|
||||
display_name="Rebatch Images _V3",
|
||||
category="image/batch",
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("IMAGE", display_name="IMAGE", is_output_list=True),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
all_images = []
|
||||
for img in images:
|
||||
for i in range(img.shape[0]):
|
||||
all_images.append(img[i:i+1])
|
||||
|
||||
for i in range(0, len(all_images), batch_size):
|
||||
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
|
||||
|
||||
return io.NodeOutput(output_list)
|
||||
|
||||
|
||||
class LatentRebatch(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="RebatchLatents_V3",
|
||||
display_name="Rebatch Latents _V3",
|
||||
category="latent/batch",
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Latent.Input("latents"),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(is_output_list=True),
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_batch(latents, list_ind, offset):
|
||||
"""prepare a batch out of the list of latents"""
|
||||
samples = latents[list_ind]['samples']
|
||||
shape = samples.shape
|
||||
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||
if mask.shape[0] < samples.shape[0]:
|
||||
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||
if 'batch_index' in latents[list_ind]:
|
||||
batch_inds = latents[list_ind]['batch_index']
|
||||
else:
|
||||
batch_inds = [x+offset for x in range(shape[0])]
|
||||
return samples, mask, batch_inds
|
||||
|
||||
@staticmethod
|
||||
def get_slices(indexable, num, batch_size):
|
||||
"""divides an indexable object into num slices of length batch_size, and a remainder"""
|
||||
slices = []
|
||||
for i in range(num):
|
||||
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||
if num * batch_size < len(indexable):
|
||||
return slices, indexable[num * batch_size:]
|
||||
else:
|
||||
return slices, None
|
||||
|
||||
@staticmethod
|
||||
def slice_batch(batch, num, batch_size):
|
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||
return list(zip(*result))
|
||||
|
||||
@staticmethod
|
||||
def cat_batch(batch1, batch2):
|
||||
if batch1[0] is None:
|
||||
return batch2
|
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
current_batch = (None, None, None)
|
||||
processed = 0
|
||||
|
||||
for i in range(len(latents)):
|
||||
# fetch new entry of list
|
||||
#samples, masks, indices = self.get_batch(latents, i)
|
||||
next_batch = cls.get_batch(latents, i, processed)
|
||||
processed += len(next_batch[2])
|
||||
# set to current if current is None
|
||||
if current_batch[0] is None:
|
||||
current_batch = next_batch
|
||||
# add previous to list if dimensions do not match
|
||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
current_batch = next_batch
|
||||
# cat if everything checks out
|
||||
else:
|
||||
current_batch = cls.cat_batch(current_batch, next_batch)
|
||||
|
||||
# add to list if dimensions gone above target batch size
|
||||
if current_batch[0].shape[0] > batch_size:
|
||||
num = current_batch[0].shape[0] // batch_size
|
||||
sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
|
||||
|
||||
for i in range(num):
|
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||
|
||||
current_batch = remainder
|
||||
|
||||
#add remainder
|
||||
if current_batch[0] is not None:
|
||||
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
|
||||
#get rid of empty masks
|
||||
for s in output_list:
|
||||
if s['noise_mask'].mean() == 1.0:
|
||||
del s['noise_mask']
|
||||
|
||||
return io.NodeOutput(output_list)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
ImageRebatch,
|
||||
LatentRebatch,
|
||||
]
|
147
execution.py
147
execution.py
@@ -8,12 +8,14 @@ import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
DependencyAwareCache,
|
||||
@@ -28,7 +30,9 @@ from comfy_execution.graph import (
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_api.internal import ComfyNodeInternal, lock_class, first_real_override, is_class
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import ComfyNodeInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
@@ -41,12 +45,13 @@ class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
if node_id in self.is_changed:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -72,7 +77,8 @@ class IsChangedCache:
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = _map_node_over_list(class_def, input_data_all, is_changed_name)
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
logging.warning("WARNING: {}".format(e))
|
||||
@@ -127,6 +133,8 @@ class CacheSet:
|
||||
}
|
||||
return result
|
||||
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, ComfyNodeInternal)
|
||||
if is_v3:
|
||||
@@ -194,7 +202,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
async def resolve_map_node_over_list_results(results):
|
||||
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
|
||||
if len(remaining) == 0:
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
else:
|
||||
done, pending = await asyncio.wait(remaining)
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -208,7 +228,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
results = []
|
||||
def process_inputs(inputs, index=None, input_is_list=False):
|
||||
async def process_inputs(inputs, index=None, input_is_list=False):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
execution_block = None
|
||||
@@ -256,23 +276,41 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
dynamic_list.append(real_inputs.pop(d.id, None))
|
||||
dynamic_list = [x for x in dynamic_list if x is not None]
|
||||
inputs = {**real_inputs, add_key: dynamic_list}
|
||||
results.append(getattr(type_obj, func).__func__(lock_class(class_clone), **inputs))
|
||||
# TODO: make checkign for async work, this will currently always return False for iscoroutinefunction
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# V1
|
||||
else:
|
||||
results.append(getattr(obj, func)(**inputs))
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
result = task.result()
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(execution_block)
|
||||
|
||||
if input_is_list:
|
||||
process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
elif max_len_input == 0:
|
||||
process_inputs({})
|
||||
await process_inputs({})
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
process_inputs(input_dict, i)
|
||||
await process_inputs(input_dict, i)
|
||||
return results
|
||||
|
||||
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
@@ -294,11 +332,18 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
|
||||
return output, ui, has_subgraph, False
|
||||
|
||||
def get_output_from_returns(return_values, obj):
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
@@ -352,6 +397,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
# TODO: Think there's an existing bug here
|
||||
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
|
||||
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
|
||||
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui, has_subgraph
|
||||
@@ -364,7 +413,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@@ -376,11 +425,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
if unique_id in pending_subgraph_results:
|
||||
if unique_id in pending_async_nodes:
|
||||
results = []
|
||||
for r in pending_async_nodes[unique_id]:
|
||||
if isinstance(r, asyncio.Task):
|
||||
try:
|
||||
results.append(r.result())
|
||||
except Exception as ex:
|
||||
# An async task failed - propagate the exception up
|
||||
del pending_async_nodes[unique_id]
|
||||
raise ex
|
||||
else:
|
||||
results.append(r)
|
||||
del pending_async_nodes[unique_id]
|
||||
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||
elif unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
@@ -402,6 +466,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
@@ -417,7 +482,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
@@ -446,8 +512,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
else:
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
async def await_completion():
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
@@ -490,7 +566,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
@@ -535,6 +612,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@@ -589,6 +667,11 @@ class PromptExecutor:
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@@ -601,9 +684,11 @@ class PromptExecutor:
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
cached_nodes = []
|
||||
@@ -616,6 +701,7 @@ class PromptExecutor:
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
@@ -623,12 +709,13 @@ class PromptExecutor:
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = execution_list.stage_node_execution()
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@@ -658,7 +745,7 @@ class PromptExecutor:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
@@ -741,7 +828,7 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
r = await validate_inputs(prompt_id, prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
@@ -865,7 +952,8 @@ def validate_inputs(prompt, item, validated):
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
ret = _map_node_over_list(obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
@@ -898,7 +986,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -941,7 +1029,7 @@ def validate_prompt(prompt):
|
||||
valid = False
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
m = await validate_inputs(prompt_id, prompt, o, validated)
|
||||
valid = m[0]
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
@@ -1054,6 +1142,11 @@ class PromptQueue:
|
||||
if status is not None:
|
||||
status_dict = copy.deepcopy(status._asdict())
|
||||
|
||||
# Remove sensitive data from extra_data before storing in history
|
||||
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
|
||||
if sensitive_val in prompt[3]:
|
||||
prompt[3].pop(sensitive_val)
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
"outputs": {},
|
||||
|
35
main.py
35
main.py
@@ -11,6 +11,9 @@ import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
@@ -127,11 +130,14 @@ if __name__ == "__main__":
|
||||
|
||||
import cuda_malloc
|
||||
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from server import BinaryEventTypes
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
@@ -227,15 +233,34 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||
)
|
||||
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
def hook(value, total, preview_image):
|
||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||
executing_context = get_executing_context()
|
||||
if prompt_id is None and executing_context is not None:
|
||||
prompt_id = executing_context.prompt_id
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
||||
if prompt_id is None:
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
if preview_image is not None:
|
||||
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||
# Only send old method if client doesn't support preview metadata
|
||||
if not feature_flags.supports_feature(
|
||||
server_instance.sockets_metadata,
|
||||
server_instance.client_id,
|
||||
"supports_preview_metadata",
|
||||
):
|
||||
server_instance.send_sync(
|
||||
BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
|
||||
preview_image,
|
||||
server_instance.client_id,
|
||||
)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
|
15
nodes.py
15
nodes.py
@@ -2302,14 +2302,27 @@ def init_builtin_extra_nodes():
|
||||
"v3/nodes_ace.py",
|
||||
"v3/nodes_advanced_samplers.py",
|
||||
"v3/nodes_align_your_steps.py",
|
||||
"v3/nodes_audio.py",
|
||||
"v3/nodes_apg.py",
|
||||
"v3/nodes_attention_multiply.py",
|
||||
"v3/nodes_audio.py",
|
||||
"v3/nodes_camera_trajectory.py",
|
||||
"v3/nodes_canny.py",
|
||||
"v3/nodes_cfg.py",
|
||||
"v3/nodes_clip_sdxl.py",
|
||||
"v3/nodes_compositing.py",
|
||||
"v3/nodes_cond.py",
|
||||
"v3/nodes_controlnet.py",
|
||||
"v3/nodes_cosmos.py",
|
||||
"v3/nodes_differential_diffusion.py",
|
||||
"v3/nodes_flux.py",
|
||||
"v3/nodes_freelunch.py",
|
||||
"v3/nodes_fresca.py",
|
||||
"v3/nodes_gits.py",
|
||||
"v3/nodes_images.py",
|
||||
"v3/nodes_mask.py",
|
||||
"v3/nodes_preview_any.py",
|
||||
"v3/nodes_primitive.py",
|
||||
"v3/nodes_rebatch.py",
|
||||
"v3/nodes_stable_cascade.py",
|
||||
"v3/nodes_webcam.py",
|
||||
]
|
||||
|
7
protocol.py
Normal file
7
protocol.py
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
PREVIEW_IMAGE_WITH_METADATA = 4
|
||||
|
@@ -23,9 +23,8 @@ lint.select = [
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
lint.ignore = ["E501"] # disable line-length checking
|
||||
exclude = ["*.ipynb"]
|
||||
line-length = 144
|
||||
lint.pycodestyle.ignore-overlong-task-comments = true
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"!comfy_extras/v3/*" = ["E", "I"] # enable these rules only for V3 nodes
|
||||
|
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.35
|
||||
comfyui-workflow-templates==0.1.36
|
||||
comfyui-embedded-docs==0.2.4
|
||||
torch
|
||||
torchsde
|
||||
|
@@ -10,11 +10,11 @@ import urllib.parse
|
||||
server_address = "127.0.0.1:8188"
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
def queue_prompt(prompt):
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
def queue_prompt(prompt, prompt_id):
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": prompt_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||||
urllib.request.urlopen(req).read()
|
||||
|
||||
def get_image(filename, subfolder, folder_type):
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
@@ -27,7 +27,8 @@ def get_history(prompt_id):
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_images(ws, prompt):
|
||||
prompt_id = queue_prompt(prompt)['prompt_id']
|
||||
prompt_id = str(uuid.uuid4())
|
||||
queue_prompt(prompt, prompt_id)
|
||||
output_images = {}
|
||||
while True:
|
||||
out = ws.recv()
|
||||
|
98
server.py
98
server.py
@@ -26,6 +26,7 @@ import mimetypes
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
@@ -36,11 +37,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
@@ -179,6 +176,7 @@ class PromptServer():
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
self.sockets_metadata = dict()
|
||||
self.web_root = (
|
||||
FrontendManager.init_frontend(args.front_end_version)
|
||||
if args.front_end_root is None
|
||||
@@ -203,20 +201,53 @@ class PromptServer():
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
# Store WebSocket for backward compatibility
|
||||
self.sockets[sid] = ws
|
||||
# Store metadata separately
|
||||
self.sockets_metadata[sid] = {"feature_flags": {}}
|
||||
|
||||
try:
|
||||
# Send initial state to the new client
|
||||
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logging.warning('ws connection closed with exception %s' % ws.exception())
|
||||
elif msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
# Check if first message is feature flags
|
||||
if first_message and data.get("type") == "feature_flags":
|
||||
# Store client feature flags
|
||||
client_flags = data.get("data", {})
|
||||
self.sockets_metadata[sid]["feature_flags"] = client_flags
|
||||
|
||||
# Send server feature flags in response
|
||||
await self.send(
|
||||
"feature_flags",
|
||||
feature_flags.get_server_features(),
|
||||
sid,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Feature flags negotiated for client {sid}: {client_flags}"
|
||||
)
|
||||
first_message = False
|
||||
except json.JSONDecodeError:
|
||||
logging.warning(
|
||||
f"Invalid JSON received from client {sid}: {msg.data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing WebSocket message: {e}")
|
||||
finally:
|
||||
self.sockets.pop(sid, None)
|
||||
self.sockets_metadata.pop(sid, None)
|
||||
return ws
|
||||
|
||||
@routes.get("/")
|
||||
@@ -549,6 +580,10 @@ class PromptServer():
|
||||
}
|
||||
return web.json_response(system_stats)
|
||||
|
||||
@routes.get("/features")
|
||||
async def get_features(request):
|
||||
return web.json_response(feature_flags.get_server_features())
|
||||
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
@@ -646,7 +681,8 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
valid = execution.validate_prompt(prompt)
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
@@ -654,7 +690,6 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
@@ -769,6 +804,10 @@ class PromptServer():
|
||||
async def send(self, event, data, sid=None):
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
await self.send_image(data, sid=sid)
|
||||
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
|
||||
# data is (preview_image, metadata)
|
||||
preview_image, metadata = data
|
||||
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
await self.send_bytes(event, data, sid)
|
||||
else:
|
||||
@@ -807,6 +846,43 @@ class PromptServer():
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
if max_size is not None:
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
|
||||
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
|
||||
|
||||
# Prepare metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["image_type"] = mimetype
|
||||
|
||||
# Serialize metadata as JSON
|
||||
import json
|
||||
metadata_json = json.dumps(metadata).encode('utf-8')
|
||||
metadata_length = len(metadata_json)
|
||||
|
||||
# Prepare image data
|
||||
bytesIO = BytesIO()
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
image_bytes = bytesIO.getvalue()
|
||||
|
||||
# Combine metadata and image
|
||||
combined_data = bytearray()
|
||||
combined_data.extend(struct.pack(">I", metadata_length))
|
||||
combined_data.extend(metadata_json)
|
||||
combined_data.extend(image_bytes)
|
||||
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
|
||||
|
||||
async def send_bytes(self, event, data, sid=None):
|
||||
message = self.encode_bytes(event, data)
|
||||
|
||||
@@ -848,10 +924,10 @@ class PromptServer():
|
||||
ssl_ctx = None
|
||||
scheme = "http"
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
|
||||
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
|
||||
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
|
||||
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
|
||||
keyfile=args.tls_keyfile)
|
||||
scheme = "https"
|
||||
scheme = "https"
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
|
98
tests-unit/feature_flags_test.py
Normal file
98
tests-unit/feature_flags_test.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Tests for feature flags functionality."""
|
||||
|
||||
from comfy_api.feature_flags import (
|
||||
get_connection_feature,
|
||||
supports_feature,
|
||||
get_server_features,
|
||||
SERVER_FEATURE_FLAGS,
|
||||
)
|
||||
|
||||
|
||||
class TestFeatureFlags:
|
||||
"""Test suite for feature flags functions."""
|
||||
|
||||
def test_get_server_features_returns_copy(self):
|
||||
"""Test that get_server_features returns a copy of the server flags."""
|
||||
features = get_server_features()
|
||||
# Verify it's a copy by modifying it
|
||||
features["test_flag"] = True
|
||||
# Original should be unchanged
|
||||
assert "test_flag" not in SERVER_FEATURE_FLAGS
|
||||
|
||||
def test_get_server_features_contains_expected_flags(self):
|
||||
"""Test that server features contain expected flags."""
|
||||
features = get_server_features()
|
||||
assert "supports_preview_metadata" in features
|
||||
assert features["supports_preview_metadata"] is True
|
||||
assert "max_upload_size" in features
|
||||
assert isinstance(features["max_upload_size"], (int, float))
|
||||
|
||||
def test_get_connection_feature_with_missing_sid(self):
|
||||
"""Test getting feature for non-existent session ID."""
|
||||
sockets_metadata = {}
|
||||
result = get_connection_feature(sockets_metadata, "missing_sid", "some_feature")
|
||||
assert result is False # Default value
|
||||
|
||||
def test_get_connection_feature_with_custom_default(self):
|
||||
"""Test getting feature with custom default value."""
|
||||
sockets_metadata = {}
|
||||
result = get_connection_feature(
|
||||
sockets_metadata, "missing_sid", "some_feature", default="custom_default"
|
||||
)
|
||||
assert result == "custom_default"
|
||||
|
||||
def test_get_connection_feature_with_feature_flags(self):
|
||||
"""Test getting feature from connection with feature flags."""
|
||||
sockets_metadata = {
|
||||
"sid1": {
|
||||
"feature_flags": {
|
||||
"supports_preview_metadata": True,
|
||||
"custom_feature": "value",
|
||||
},
|
||||
}
|
||||
}
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "supports_preview_metadata")
|
||||
assert result is True
|
||||
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "custom_feature")
|
||||
assert result == "value"
|
||||
|
||||
def test_get_connection_feature_missing_feature(self):
|
||||
"""Test getting non-existent feature from connection."""
|
||||
sockets_metadata = {
|
||||
"sid1": {"feature_flags": {"existing_feature": True}}
|
||||
}
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "missing_feature")
|
||||
assert result is False
|
||||
|
||||
def test_supports_feature_returns_boolean(self):
|
||||
"""Test that supports_feature always returns boolean."""
|
||||
sockets_metadata = {
|
||||
"sid1": {
|
||||
"feature_flags": {
|
||||
"bool_feature": True,
|
||||
"string_feature": "value",
|
||||
"none_feature": None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# True boolean feature
|
||||
assert supports_feature(sockets_metadata, "sid1", "bool_feature") is True
|
||||
|
||||
# Non-boolean values should return False
|
||||
assert supports_feature(sockets_metadata, "sid1", "string_feature") is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "none_feature") is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "missing_feature") is False
|
||||
|
||||
def test_supports_feature_with_missing_connection(self):
|
||||
"""Test supports_feature with missing connection."""
|
||||
sockets_metadata = {}
|
||||
assert supports_feature(sockets_metadata, "missing_sid", "any_feature") is False
|
||||
|
||||
def test_empty_feature_flags_dict(self):
|
||||
"""Test connection with empty feature flags dictionary."""
|
||||
sockets_metadata = {"sid1": {"feature_flags": {}}}
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
|
||||
assert result is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
|
@@ -1,3 +1,4 @@
|
||||
pytest>=7.8.0
|
||||
pytest-aiohttp
|
||||
pytest-asyncio
|
||||
websocket-client
|
||||
|
77
tests-unit/websocket_feature_flags_test.py
Normal file
77
tests-unit/websocket_feature_flags_test.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Simplified tests for WebSocket feature flags functionality."""
|
||||
from comfy_api import feature_flags
|
||||
|
||||
|
||||
class TestWebSocketFeatureFlags:
|
||||
"""Test suite for WebSocket feature flags integration."""
|
||||
|
||||
def test_server_feature_flags_response(self):
|
||||
"""Test server feature flags are properly formatted."""
|
||||
features = feature_flags.get_server_features()
|
||||
|
||||
# Check expected server features
|
||||
assert "supports_preview_metadata" in features
|
||||
assert features["supports_preview_metadata"] is True
|
||||
assert "max_upload_size" in features
|
||||
assert isinstance(features["max_upload_size"], (int, float))
|
||||
|
||||
def test_progress_py_checks_feature_flags(self):
|
||||
"""Test that progress.py checks feature flags before sending metadata."""
|
||||
# This simulates the check in progress.py
|
||||
client_id = "test_client"
|
||||
sockets_metadata = {"test_client": {"feature_flags": {}}}
|
||||
|
||||
# The actual check would be in progress.py
|
||||
supports_metadata = feature_flags.supports_feature(
|
||||
sockets_metadata, client_id, "supports_preview_metadata"
|
||||
)
|
||||
|
||||
assert supports_metadata is False
|
||||
|
||||
def test_multiple_clients_different_features(self):
|
||||
"""Test handling multiple clients with different feature support."""
|
||||
sockets_metadata = {
|
||||
"modern_client": {
|
||||
"feature_flags": {"supports_preview_metadata": True}
|
||||
},
|
||||
"legacy_client": {
|
||||
"feature_flags": {}
|
||||
}
|
||||
}
|
||||
|
||||
# Check modern client
|
||||
assert feature_flags.supports_feature(
|
||||
sockets_metadata, "modern_client", "supports_preview_metadata"
|
||||
) is True
|
||||
|
||||
# Check legacy client
|
||||
assert feature_flags.supports_feature(
|
||||
sockets_metadata, "legacy_client", "supports_preview_metadata"
|
||||
) is False
|
||||
|
||||
def test_feature_negotiation_message_format(self):
|
||||
"""Test the format of feature negotiation messages."""
|
||||
# Client message format
|
||||
client_message = {
|
||||
"type": "feature_flags",
|
||||
"data": {
|
||||
"supports_preview_metadata": True,
|
||||
"api_version": "1.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
# Verify structure
|
||||
assert client_message["type"] == "feature_flags"
|
||||
assert "supports_preview_metadata" in client_message["data"]
|
||||
|
||||
# Server response format (what would be sent)
|
||||
server_features = feature_flags.get_server_features()
|
||||
server_message = {
|
||||
"type": "feature_flags",
|
||||
"data": server_features
|
||||
}
|
||||
|
||||
# Verify structure
|
||||
assert server_message["type"] == "feature_flags"
|
||||
assert "supports_preview_metadata" in server_message["data"]
|
||||
assert server_message["data"]["supports_preview_metadata"] is True
|
@@ -1,4 +1,4 @@
|
||||
# Config for testing nodes
|
||||
testing:
|
||||
custom_nodes: tests/inference/testing_nodes
|
||||
custom_nodes: testing_nodes
|
||||
|
||||
|
410
tests/inference/test_async_nodes.py
Normal file
410
tests/inference/test_async_nodes.py
Normal file
@@ -0,0 +1,410 @@
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
import urllib.error
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestAsyncNodes:
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
# Running server with args: pargs
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
client = ComfyClient()
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||
except ConnectionRefusedError:
|
||||
# Retrying...
|
||||
pass
|
||||
else:
|
||||
break
|
||||
yield client
|
||||
del client
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, shared_client, request):
|
||||
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
# Happy Path Tests
|
||||
|
||||
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that a basic async node executes correctly."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
|
||||
output = g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution completed
|
||||
assert result.did_run(sleep_node), "Async sleep node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify the image passed through correctly
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async sleep nodes with different durations
|
||||
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
|
||||
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Add outputs for each
|
||||
_output1 = g.node("PreviewImage", images=sleep1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
|
||||
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
|
||||
|
||||
# Verify all nodes executed
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
|
||||
|
||||
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with proper dependency handling."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Chain of async operations
|
||||
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||
|
||||
# Average depends on both async results
|
||||
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
|
||||
output = g.node("SaveImage", images=average.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution order
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2)
|
||||
assert result.did_run(average) and result.did_run(output)
|
||||
|
||||
# Verify averaged result
|
||||
result_images = result.get_images(output)
|
||||
avg_value = np.array(result_images[0]).mean()
|
||||
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
|
||||
|
||||
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async VALIDATE_INPUTS function."""
|
||||
g = builder
|
||||
# Create a test node with async validation
|
||||
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
|
||||
g.node("SaveImage", images=validation_node.out(0))
|
||||
|
||||
# Should pass validation
|
||||
result = client.run(g)
|
||||
assert result.did_run(validation_node)
|
||||
|
||||
# Test validation failure
|
||||
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
|
||||
with pytest.raises(urllib.error.HTTPError):
|
||||
client.run(g)
|
||||
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with lazy evaluation."""
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
|
||||
|
||||
# Create async nodes that will be evaluated lazily
|
||||
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)
|
||||
|
||||
# Use lazy mix that only needs sleep1 (mask=0.0)
|
||||
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
|
||||
g.node("SaveImage", images=lazy_mix.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should only execute sleep1, not sleep2
|
||||
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
|
||||
assert result.did_run(sleep1), "Sleep1 should have executed"
|
||||
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
|
||||
|
||||
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async check_lazy_status function."""
|
||||
g = builder
|
||||
# Create a node with async check_lazy_status
|
||||
lazy_node = g.node("TestAsyncLazyCheck",
|
||||
input1="value1",
|
||||
input2="value2",
|
||||
condition=True)
|
||||
g.node("SaveImage", images=lazy_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
assert result.did_run(lazy_node)
|
||||
|
||||
# Error Handling Tests
|
||||
|
||||
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async execution errors are properly handled."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Create an async node that will error
|
||||
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||
g.node("SaveImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
|
||||
|
||||
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async validation error handling."""
|
||||
g = builder
|
||||
# Node with async validation that will fail
|
||||
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
|
||||
g.node("SaveImage", images=validation_node.out(0))
|
||||
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.run(g)
|
||||
# Verify it's a validation error
|
||||
assert exc_info.value.code == 400
|
||||
|
||||
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test handling of async operations that timeout."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Very long sleep that would timeout
|
||||
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
|
||||
g.node("SaveImage", images=timeout_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised a timeout error"
|
||||
except Exception as e:
|
||||
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
|
||||
|
||||
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that workflow can recover after async errors."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# First run with error
|
||||
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||
g.node("SaveImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
except Exception:
|
||||
pass # Expected
|
||||
|
||||
# Second run should succeed
|
||||
g2 = GraphBuilder(prefix="recovery_test")
|
||||
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
|
||||
g2.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g2)
|
||||
assert result.did_run(sleep_node), "Should be able to run after error"
|
||||
|
||||
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test handling when sync node errors while async node is executing."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Async node that takes time
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Sync node that will error immediately
|
||||
error_node = g.node("TestSyncError", value=image.out(0))
|
||||
|
||||
# Both feed into output
|
||||
g.node("PreviewImage", images=sleep_node.out(0))
|
||||
g.node("PreviewImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
# Verify the sync error was caught even though async was running
|
||||
assert 'prompt_id' in e.args[0]
|
||||
|
||||
# Edge Cases
|
||||
|
||||
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with execution blockers."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Async sleep nodes
|
||||
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||
|
||||
# Create list of images
|
||||
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))
|
||||
|
||||
# Create list of blocking conditions - [False, True] to block only the second item
|
||||
int1 = g.node("StubInt", value=1)
|
||||
int2 = g.node("StubInt", value=2)
|
||||
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))
|
||||
|
||||
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
|
||||
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")
|
||||
|
||||
# Block based on the comparison results
|
||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||
|
||||
output = g.node("PreviewImage", images=blocker.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have blocked second image"
|
||||
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async nodes are properly cached."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||
g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
# First run
|
||||
result1 = client.run(g)
|
||||
assert result1.did_run(sleep_node), "Should run first time"
|
||||
|
||||
# Second run - should be cached
|
||||
start_time = time.time()
|
||||
result2 = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
assert not result2.did_run(sleep_node), "Should be cached"
|
||||
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
|
||||
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes within dynamically generated prompts."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Node that generates async nodes dynamically
|
||||
dynamic_async = g.node("TestDynamicAsyncGeneration",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
num_async_nodes=3,
|
||||
sleep_duration=0.2)
|
||||
g.node("SaveImage", images=dynamic_async.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should execute async nodes in parallel within dynamic prompt
|
||||
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
|
||||
assert result.did_run(dynamic_async)
|
||||
|
||||
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async resources are properly cleaned up."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async nodes that use resources
|
||||
resource_nodes = []
|
||||
for i in range(5):
|
||||
node = g.node("TestAsyncResourceUser",
|
||||
value=image.out(0),
|
||||
resource_id=f"resource_{i}",
|
||||
duration=0.1)
|
||||
resource_nodes.append(node)
|
||||
g.node("PreviewImage", images=node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify all nodes executed
|
||||
for node in resource_nodes:
|
||||
assert result.did_run(node)
|
||||
|
||||
# Run again to ensure resources were cleaned up
|
||||
result2 = client.run(g)
|
||||
# Should be cached but not error due to resource conflicts
|
||||
for node in resource_nodes:
|
||||
assert not result2.did_run(node), "Should be cached"
|
||||
|
||||
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test cancellation of async operations."""
|
||||
# This would require implementing cancellation in the client
|
||||
# For now, we'll test that long-running async operations can be interrupted
|
||||
pass # TODO: Implement when cancellation API is available
|
||||
|
||||
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test workflows with both sync and async nodes."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||
|
||||
# Mix of sync and async operations
|
||||
# Sync: lazy mix images
|
||||
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
|
||||
# Async: sleep
|
||||
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
|
||||
# Sync: custom validation
|
||||
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
|
||||
# Async: sleep again
|
||||
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)
|
||||
|
||||
output = g.node("SaveImage", images=async_op2.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify all nodes executed in correct order
|
||||
assert result.did_run(sync_op1)
|
||||
assert result.did_run(async_op1)
|
||||
assert result.did_run(sync_op2)
|
||||
assert result.did_run(async_op2)
|
||||
|
||||
# Image should be a mix of black and white (gray)
|
||||
result_images = result.get_images(output)
|
||||
avg_value = np.array(result_images[0]).mean()
|
||||
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
|
@@ -252,7 +252,7 @@ class TestExecution:
|
||||
|
||||
@pytest.mark.parametrize("test_type, test_value", [
|
||||
("StubInt", 5),
|
||||
("StubFloat", 5.0)
|
||||
("StubMask", 5.0)
|
||||
])
|
||||
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@@ -497,6 +497,69 @@ class TestExecution:
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create sleep nodes for each duration
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||
|
||||
# Add outputs to verify the execution
|
||||
_output1 = g.node("PreviewImage", images=sleep_node1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep_node2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep_node3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
||||
# We'll allow for up to 0.8s total to account for overhead
|
||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
||||
|
||||
# Verify that all nodes executed
|
||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
|
||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
# Create input images with different values
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
|
||||
parallel_sleep = g.node("TestParallelSleep",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
image3=image3.out(0),
|
||||
sleep1=0.4,
|
||||
sleep2=0.5,
|
||||
sleep3=0.6)
|
||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||
# which should complete in less than the sum of all sleeps
|
||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
||||
|
||||
# Verify the parallel sleep node executed
|
||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||
|
||||
# Verify we get an image as output (blend of the three input images)
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
# Average pixel value should be around 170 (255 * 2 // 3)
|
||||
avg_value = numpy.array(result_images[0]).mean()
|
||||
assert avg_value == 170, f"Image average value {avg_value} should be 170"
|
||||
|
||||
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||
# only that one entry in the list is blocked.
|
||||
|
@@ -1,23 +1,26 @@
|
||||
from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
|
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import torch
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
|
||||
class TestAsyncValidation(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"threshold": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||
# Simulate async validation (e.g., checking remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if value > threshold:
|
||||
return f"Value {value} exceeds threshold {threshold}"
|
||||
return True
|
||||
|
||||
def process(self, value, threshold):
|
||||
# Create image based on value
|
||||
intensity = value / 10.0
|
||||
image = torch.ones([1, 512, 512, 3]) * intensity
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncError(ComfyNodeABC):
|
||||
"""Test node that errors during async execution."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "error_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def error_execution(self, value, error_after):
|
||||
await asyncio.sleep(error_after)
|
||||
raise RuntimeError("Intentional async execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncValidationError(ComfyNodeABC):
|
||||
"""Test node with async validation that always fails."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"max_value": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||
await asyncio.sleep(0.05)
|
||||
# Always fail validation for values > max_value
|
||||
if value > max_value:
|
||||
return f"Async validation failed: {value} > {max_value}"
|
||||
return True
|
||||
|
||||
def process(self, value, max_value):
|
||||
# This won't be reached if validation fails
|
||||
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncTimeout(ComfyNodeABC):
|
||||
"""Test node that simulates timeout scenarios."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
|
||||
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "timeout_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def timeout_execution(self, value, timeout, operation_time):
|
||||
try:
|
||||
# This will timeout if operation_time > timeout
|
||||
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
|
||||
return (value,)
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError(f"Operation timed out after {timeout} seconds")
|
||||
|
||||
|
||||
class TestSyncError(ComfyNodeABC):
|
||||
"""Test node that errors synchronously (for mixed sync/async testing)."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sync_error"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def sync_error(self, value):
|
||||
raise RuntimeError("Intentional sync execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncLazyCheck(ComfyNodeABC):
|
||||
"""Test node with async check_lazy_status."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": (IO.ANY, {"lazy": True}),
|
||||
"input2": (IO.ANY, {"lazy": True}),
|
||||
"condition": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def check_lazy_status(self, condition, input1, input2):
|
||||
# Simulate async checking (e.g., querying remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
needed = []
|
||||
if condition and input1 is None:
|
||||
needed.append("input1")
|
||||
if not condition and input2 is None:
|
||||
needed.append("input2")
|
||||
return needed
|
||||
|
||||
def process(self, input1, input2, condition):
|
||||
# Return a simple image
|
||||
return (torch.ones([1, 512, 512, 3]),)
|
||||
|
||||
|
||||
class TestDynamicAsyncGeneration(ComfyNodeABC):
|
||||
"""Test node that dynamically generates async nodes."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"image2": ("IMAGE",),
|
||||
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
|
||||
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "generate_async_workflow"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create multiple async sleep nodes
|
||||
sleep_nodes = []
|
||||
for i in range(num_async_nodes):
|
||||
image = image1 if i % 2 == 0 else image2
|
||||
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
|
||||
sleep_nodes.append(sleep_node)
|
||||
|
||||
# Average all results
|
||||
if len(sleep_nodes) == 1:
|
||||
final_node = sleep_nodes[0]
|
||||
else:
|
||||
avg_inputs = {"input1": sleep_nodes[0].out(0)}
|
||||
for i, node in enumerate(sleep_nodes[1:], 2):
|
||||
avg_inputs[f"input{i}"] = node.out(0)
|
||||
final_node = g.node("TestVariadicAverage", **avg_inputs)
|
||||
|
||||
return {
|
||||
"result": (final_node.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
|
||||
class TestAsyncResourceUser(ComfyNodeABC):
|
||||
"""Test node that uses resources during async execution."""
|
||||
|
||||
# Class-level resource tracking for testing
|
||||
_active_resources: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"resource_id": ("STRING", {"default": "resource_0"}),
|
||||
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "use_resource"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def use_resource(self, value, resource_id, duration):
|
||||
# Check if resource is already in use
|
||||
if self._active_resources.get(resource_id, False):
|
||||
raise RuntimeError(f"Resource {resource_id} is already in use!")
|
||||
|
||||
# Mark resource as in use
|
||||
self._active_resources[resource_id] = True
|
||||
|
||||
try:
|
||||
# Simulate resource usage
|
||||
await asyncio.sleep(duration)
|
||||
return (value,)
|
||||
finally:
|
||||
# Always clean up resource
|
||||
self._active_resources[resource_id] = False
|
||||
|
||||
|
||||
class TestAsyncBatchProcessing(ComfyNodeABC):
|
||||
"""Test async processing of batched inputs."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process_batch"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||
batch_size = images.shape[0]
|
||||
pbar = ProgressBar(batch_size, node_id=unique_id)
|
||||
|
||||
# Process each image in the batch
|
||||
processed = []
|
||||
for i in range(batch_size):
|
||||
# Simulate async processing
|
||||
await asyncio.sleep(process_time_per_item)
|
||||
|
||||
# Simple processing: invert the image
|
||||
processed_image = 1.0 - images[i:i+1]
|
||||
processed.append(processed_image)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Stack processed images
|
||||
result = torch.cat(processed, dim=0)
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestAsyncConcurrentLimit(ComfyNodeABC):
|
||||
"""Test concurrent execution limits for async nodes."""
|
||||
|
||||
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
|
||||
"node_id": ("INT", {"default": 0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "limited_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def limited_execution(self, value, duration, node_id):
|
||||
async with self._semaphore:
|
||||
# Node {node_id} acquired semaphore
|
||||
await asyncio.sleep(duration)
|
||||
# Node {node_id} releasing semaphore
|
||||
return (value,)
|
||||
|
||||
|
||||
# Add node mappings
|
||||
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestAsyncValidation": TestAsyncValidation,
|
||||
"TestAsyncError": TestAsyncError,
|
||||
"TestAsyncValidationError": TestAsyncValidationError,
|
||||
"TestAsyncTimeout": TestAsyncTimeout,
|
||||
"TestSyncError": TestSyncError,
|
||||
"TestAsyncLazyCheck": TestAsyncLazyCheck,
|
||||
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
|
||||
"TestAsyncResourceUser": TestAsyncResourceUser,
|
||||
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
|
||||
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
|
||||
}
|
||||
|
||||
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAsyncValidation": "Test Async Validation",
|
||||
"TestAsyncError": "Test Async Error",
|
||||
"TestAsyncValidationError": "Test Async Validation Error",
|
||||
"TestAsyncTimeout": "Test Async Timeout",
|
||||
"TestSyncError": "Test Sync Error",
|
||||
"TestAsyncLazyCheck": "Test Async Lazy Check",
|
||||
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
|
||||
"TestAsyncResourceUser": "Test Async Resource User",
|
||||
"TestAsyncBatchProcessing": "Test Async Batch Processing",
|
||||
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
|
||||
}
|
@@ -1,6 +1,11 @@
|
||||
import torch
|
||||
import time
|
||||
import asyncio
|
||||
from comfy.utils import ProgressBar
|
||||
from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSamplingInExpansion:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
"vae": ("VAE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
|
||||
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
|
||||
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
|
||||
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "sampling_in_expansion"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create a basic image generation workflow using the input model, clip and vae
|
||||
# 1. Setup text prompts using the provided CLIP model
|
||||
positive_prompt = g.node("CLIPTextEncode",
|
||||
text=prompt,
|
||||
clip=clip)
|
||||
negative_prompt = g.node("CLIPTextEncode",
|
||||
text=negative_prompt,
|
||||
clip=clip)
|
||||
|
||||
# 2. Create empty latent with specified size
|
||||
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
|
||||
|
||||
# 3. Setup sampler and generate image latent
|
||||
sampler = g.node("KSampler",
|
||||
model=model,
|
||||
positive=positive_prompt.out(0),
|
||||
negative=negative_prompt.out(0),
|
||||
latent_image=empty_latent.out(0),
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
sampler_name="euler_ancestral",
|
||||
scheduler="normal")
|
||||
|
||||
# 4. Decode latent to image using VAE
|
||||
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
|
||||
|
||||
return {
|
||||
"result": (output.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sleep"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
async def sleep(self, value, seconds, unique_id):
|
||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||
start = time.time()
|
||||
expiration = start + seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
pbar.update_absolute(now - start)
|
||||
await asyncio.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
class TestParallelSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE", ),
|
||||
"image2": ("IMAGE", ),
|
||||
"image3": ("IMAGE", ),
|
||||
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "parallel_sleep"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||
# Create a graph dynamically with three TestSleep nodes
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create sleep nodes for each duration and image
|
||||
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
|
||||
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
|
||||
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
|
||||
|
||||
# Blend the results using TestVariadicAverage
|
||||
blend = g.node("TestVariadicAverage",
|
||||
input1=sleep_node1.out(0),
|
||||
input2=sleep_node2.out(0),
|
||||
input3=sleep_node3.out(0))
|
||||
|
||||
return {
|
||||
"result": (blend.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestCustomValidation5": TestCustomValidation5,
|
||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestCustomValidation5": "Custom Validation 5",
|
||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
}
|
||||
|
Reference in New Issue
Block a user