mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Add support for partial execution in backend (#9123)
When a prompt is submitted, it can optionally include `partial_execution_targets` as a list of ids. If it does, rather than adding all outputs to the execution list, we add only those in the list.
This commit is contained in:
@@ -7,7 +7,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import List, Literal, NamedTuple, Optional, Union
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -891,7 +891,7 @@ def full_type_name(klass):
|
|||||||
return klass.__qualname__
|
return klass.__qualname__
|
||||||
return module + '.' + klass.__qualname__
|
return module + '.' + klass.__qualname__
|
||||||
|
|
||||||
async def validate_prompt(prompt_id, prompt):
|
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
if 'class_type' not in prompt[x]:
|
if 'class_type' not in prompt[x]:
|
||||||
@@ -915,6 +915,7 @@ async def validate_prompt(prompt_id, prompt):
|
|||||||
return (False, error, [], {})
|
return (False, error, [], {})
|
||||||
|
|
||||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||||
|
if partial_execution_list is None or x in partial_execution_list:
|
||||||
outputs.add(x)
|
outputs.add(x)
|
||||||
|
|
||||||
if len(outputs) == 0:
|
if len(outputs) == 0:
|
||||||
|
@@ -681,7 +681,12 @@ class PromptServer():
|
|||||||
if "prompt" in json_data:
|
if "prompt" in json_data:
|
||||||
prompt = json_data["prompt"]
|
prompt = json_data["prompt"]
|
||||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
|
||||||
|
partial_execution_targets = None
|
||||||
|
if "partial_execution_targets" in json_data:
|
||||||
|
partial_execution_targets = json_data["partial_execution_targets"]
|
||||||
|
|
||||||
|
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
|
||||||
extra_data = {}
|
extra_data = {}
|
||||||
if "extra_data" in json_data:
|
if "extra_data" in json_data:
|
||||||
extra_data = json_data["extra_data"]
|
extra_data = json_data["extra_data"]
|
||||||
|
@@ -7,7 +7,7 @@ import subprocess
|
|||||||
|
|
||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
from tests.inference.test_execution import ComfyClient
|
from tests.inference.test_execution import ComfyClient, run_warmup
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.execution
|
@pytest.mark.execution
|
||||||
@@ -24,6 +24,7 @@ class TestAsyncNodes:
|
|||||||
'--listen', args_pytest["listen"],
|
'--listen', args_pytest["listen"],
|
||||||
'--port', str(args_pytest["port"]),
|
'--port', str(args_pytest["port"]),
|
||||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||||
|
'--cpu',
|
||||||
]
|
]
|
||||||
use_lru, lru_size = request.param
|
use_lru, lru_size = request.param
|
||||||
if use_lru:
|
if use_lru:
|
||||||
@@ -82,6 +83,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test that multiple async nodes execute in parallel."""
|
"""Test that multiple async nodes execute in parallel."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
run_warmup(client)
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
@@ -148,6 +152,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test async nodes with lazy evaluation."""
|
"""Test async nodes with lazy evaluation."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
run_warmup(client, prefix="warmup_lazy")
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
@@ -305,6 +312,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test that async nodes are properly cached."""
|
"""Test that async nodes are properly cached."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
run_warmup(client, prefix="warmup_cache")
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||||
@@ -324,6 +334,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test async nodes within dynamically generated prompts."""
|
"""Test async nodes within dynamically generated prompts."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
run_warmup(client, prefix="warmup_dynamic")
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
@@ -15,10 +15,18 @@ import urllib.parse
|
|||||||
import urllib.error
|
import urllib.error
|
||||||
from comfy_execution.graph_utils import GraphBuilder, Node
|
from comfy_execution.graph_utils import GraphBuilder, Node
|
||||||
|
|
||||||
|
def run_warmup(client, prefix="warmup"):
|
||||||
|
"""Run a simple workflow to warm up the server."""
|
||||||
|
warmup_g = GraphBuilder(prefix=prefix)
|
||||||
|
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
|
||||||
|
warmup_g.node("PreviewImage", images=warmup_image.out(0))
|
||||||
|
client.run(warmup_g)
|
||||||
|
|
||||||
class RunResult:
|
class RunResult:
|
||||||
def __init__(self, prompt_id: str):
|
def __init__(self, prompt_id: str):
|
||||||
self.outputs: Dict[str,Dict] = {}
|
self.outputs: Dict[str,Dict] = {}
|
||||||
self.runs: Dict[str,bool] = {}
|
self.runs: Dict[str,bool] = {}
|
||||||
|
self.cached: Dict[str,bool] = {}
|
||||||
self.prompt_id: str = prompt_id
|
self.prompt_id: str = prompt_id
|
||||||
|
|
||||||
def get_output(self, node: Node):
|
def get_output(self, node: Node):
|
||||||
@@ -27,6 +35,13 @@ class RunResult:
|
|||||||
def did_run(self, node: Node):
|
def did_run(self, node: Node):
|
||||||
return self.runs.get(node.id, False)
|
return self.runs.get(node.id, False)
|
||||||
|
|
||||||
|
def was_cached(self, node: Node):
|
||||||
|
return self.cached.get(node.id, False)
|
||||||
|
|
||||||
|
def was_executed(self, node: Node):
|
||||||
|
"""Returns True if node was either run or cached"""
|
||||||
|
return self.did_run(node) or self.was_cached(node)
|
||||||
|
|
||||||
def get_images(self, node: Node):
|
def get_images(self, node: Node):
|
||||||
output = self.get_output(node)
|
output = self.get_output(node)
|
||||||
if output is None:
|
if output is None:
|
||||||
@@ -51,8 +66,10 @@ class ComfyClient:
|
|||||||
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
def queue_prompt(self, prompt):
|
def queue_prompt(self, prompt, partial_execution_targets=None):
|
||||||
p = {"prompt": prompt, "client_id": self.client_id}
|
p = {"prompt": prompt, "client_id": self.client_id}
|
||||||
|
if partial_execution_targets is not None:
|
||||||
|
p["partial_execution_targets"] = partial_execution_targets
|
||||||
data = json.dumps(p).encode('utf-8')
|
data = json.dumps(p).encode('utf-8')
|
||||||
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
||||||
return json.loads(urllib.request.urlopen(req).read())
|
return json.loads(urllib.request.urlopen(req).read())
|
||||||
@@ -70,13 +87,13 @@ class ComfyClient:
|
|||||||
def set_test_name(self, name):
|
def set_test_name(self, name):
|
||||||
self.test_name = name
|
self.test_name = name
|
||||||
|
|
||||||
def run(self, graph):
|
def run(self, graph, partial_execution_targets=None):
|
||||||
prompt = graph.finalize()
|
prompt = graph.finalize()
|
||||||
for node in graph.nodes.values():
|
for node in graph.nodes.values():
|
||||||
if node.class_type == 'SaveImage':
|
if node.class_type == 'SaveImage':
|
||||||
node.inputs['filename_prefix'] = self.test_name
|
node.inputs['filename_prefix'] = self.test_name
|
||||||
|
|
||||||
prompt_id = self.queue_prompt(prompt)['prompt_id']
|
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
|
||||||
result = RunResult(prompt_id)
|
result = RunResult(prompt_id)
|
||||||
while True:
|
while True:
|
||||||
out = self.ws.recv()
|
out = self.ws.recv()
|
||||||
@@ -92,7 +109,10 @@ class ComfyClient:
|
|||||||
elif message['type'] == 'execution_error':
|
elif message['type'] == 'execution_error':
|
||||||
raise Exception(message['data'])
|
raise Exception(message['data'])
|
||||||
elif message['type'] == 'execution_cached':
|
elif message['type'] == 'execution_cached':
|
||||||
pass # Probably want to store this off for testing
|
if message['data']['prompt_id'] == prompt_id:
|
||||||
|
cached_nodes = message['data'].get('nodes', [])
|
||||||
|
for node_id in cached_nodes:
|
||||||
|
result.cached[node_id] = True
|
||||||
|
|
||||||
history = self.get_history(prompt_id)[prompt_id]
|
history = self.get_history(prompt_id)[prompt_id]
|
||||||
for node_id in history['outputs']:
|
for node_id in history['outputs']:
|
||||||
@@ -130,6 +150,7 @@ class TestExecution:
|
|||||||
'--listen', args_pytest["listen"],
|
'--listen', args_pytest["listen"],
|
||||||
'--port', str(args_pytest["port"]),
|
'--port', str(args_pytest["port"]),
|
||||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||||
|
'--cpu',
|
||||||
]
|
]
|
||||||
use_lru, lru_size = request.param
|
use_lru, lru_size = request.param
|
||||||
if use_lru:
|
if use_lru:
|
||||||
@@ -498,12 +519,15 @@ class TestExecution:
|
|||||||
assert not result.did_run(test_node), "The execution should have been cached"
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
||||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
run_warmup(client)
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
# Create sleep nodes for each duration
|
# Create sleep nodes for each duration
|
||||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
|
||||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||||
|
|
||||||
# Add outputs to verify the execution
|
# Add outputs to verify the execution
|
||||||
@@ -515,10 +539,9 @@ class TestExecution:
|
|||||||
result = client.run(g)
|
result = client.run(g)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
# The test should take around 3.0 seconds (the longest sleep duration)
|
||||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
|
||||||
# We'll allow for up to 0.8s total to account for overhead
|
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
|
||||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
|
||||||
|
|
||||||
# Verify that all nodes executed
|
# Verify that all nodes executed
|
||||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||||
@@ -526,6 +549,9 @@ class TestExecution:
|
|||||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||||
|
|
||||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
run_warmup(client)
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
# Create input images with different values
|
# Create input images with different values
|
||||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
@@ -537,9 +563,9 @@ class TestExecution:
|
|||||||
image1=image1.out(0),
|
image1=image1.out(0),
|
||||||
image2=image2.out(0),
|
image2=image2.out(0),
|
||||||
image3=image3.out(0),
|
image3=image3.out(0),
|
||||||
sleep1=0.4,
|
sleep1=4.8,
|
||||||
sleep2=0.5,
|
sleep2=4.9,
|
||||||
sleep3=0.6)
|
sleep3=5.0)
|
||||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -548,7 +574,7 @@ class TestExecution:
|
|||||||
|
|
||||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||||
# which should complete in less than the sum of all sleeps
|
# which should complete in less than the sum of all sleeps
|
||||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
|
||||||
|
|
||||||
# Verify the parallel sleep node executed
|
# Verify the parallel sleep node executed
|
||||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||||
@@ -585,3 +611,151 @@ class TestExecution:
|
|||||||
assert len(images) == 2, "Should have 2 images"
|
assert len(images) == 2, "Should have 2 images"
|
||||||
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||||
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||||
|
|
||||||
|
# Output nodes included in the partial execution list are executed
|
||||||
|
def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create two separate output nodes
|
||||||
|
output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=input2.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only output1
|
||||||
|
result = client.run(g, partial_execution_targets=[output1.id])
|
||||||
|
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
|
||||||
|
assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run"
|
||||||
|
assert not result.did_run(output2), "Output2 should not have run"
|
||||||
|
|
||||||
|
# Verify only output1 produced results
|
||||||
|
assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
|
||||||
|
assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
|
||||||
|
|
||||||
|
# Output nodes NOT included in the partial execution list are NOT executed
|
||||||
|
def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create three output nodes
|
||||||
|
output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=input2.out(0))
|
||||||
|
output3 = g.node("SaveImage", images=input3.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only output1 and output3
|
||||||
|
result = client.run(g, partial_execution_targets=[output1.id, output3.id])
|
||||||
|
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert result.was_executed(input3), "Input3 should have been executed"
|
||||||
|
assert result.was_executed(output1), "Output1 should have been executed"
|
||||||
|
assert result.was_executed(output3), "Output3 should have been executed"
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run"
|
||||||
|
assert not result.did_run(output2), "Output2 should not have run"
|
||||||
|
|
||||||
|
# Output nodes NOT in list ARE executed if necessary for nodes that are in the list
|
||||||
|
def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create a processing chain with an OUTPUT_NODE that has socket outputs
|
||||||
|
output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
|
||||||
|
|
||||||
|
# Create another node that depends on the output_with_socket
|
||||||
|
dependent_node = g.node("TestLazyMixImages",
|
||||||
|
image1=output_with_socket.out(0),
|
||||||
|
image2=input1.out(0),
|
||||||
|
mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
|
||||||
|
|
||||||
|
# Create the final output
|
||||||
|
final_output = g.node("SaveImage", images=dependent_node.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only the final output
|
||||||
|
result = client.run(g, partial_execution_targets=[final_output.id])
|
||||||
|
|
||||||
|
# All nodes should have been executed because they're dependencies
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
|
||||||
|
assert result.was_executed(dependent_node), "Dependent node should have been executed"
|
||||||
|
assert result.was_executed(final_output), "Final output should have been executed"
|
||||||
|
|
||||||
|
# Lazy execution works with partial execution
|
||||||
|
def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create masks that will trigger different lazy execution paths
|
||||||
|
mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1
|
||||||
|
mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images
|
||||||
|
|
||||||
|
# Create two lazy mix nodes
|
||||||
|
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
|
||||||
|
lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
|
||||||
|
|
||||||
|
output1 = g.node("SaveImage", images=lazy_mix1.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=lazy_mix2.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only output1
|
||||||
|
result = client.run(g, partial_execution_targets=[output1.id])
|
||||||
|
|
||||||
|
# For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
|
||||||
|
assert result.was_executed(mask1), "Mask1 should have been executed"
|
||||||
|
assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
|
||||||
|
assert result.was_executed(output1), "Output1 should have been executed"
|
||||||
|
|
||||||
|
# Nothing from output2 path should run
|
||||||
|
assert not result.did_run(input3), "Input3 should not have run"
|
||||||
|
assert not result.did_run(mask2), "Mask2 should not have run"
|
||||||
|
assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
|
||||||
|
assert not result.did_run(output2), "Output2 should not have run"
|
||||||
|
|
||||||
|
# Multiple OUTPUT_NODEs with dependencies
|
||||||
|
def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create a chain of OUTPUT_NODEs
|
||||||
|
output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
|
||||||
|
output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
|
||||||
|
|
||||||
|
# Create regular output nodes
|
||||||
|
save1 = g.node("SaveImage", images=output_node1.out(0))
|
||||||
|
save2 = g.node("SaveImage", images=output_node2.out(0))
|
||||||
|
save3 = g.node("SaveImage", images=input2.out(0))
|
||||||
|
|
||||||
|
# Run targeting only save2
|
||||||
|
result = client.run(g, partial_execution_targets=[save2.id])
|
||||||
|
|
||||||
|
# Should run: input1, output_node1, output_node2, save2
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
|
||||||
|
assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
|
||||||
|
assert result.was_executed(save2), "Save2 should have been executed"
|
||||||
|
|
||||||
|
# Should NOT run: input2, save1, save3
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run"
|
||||||
|
assert not result.did_run(save1), "Save1 should not have run"
|
||||||
|
assert not result.did_run(save3), "Save3 should not have run"
|
||||||
|
|
||||||
|
# Empty partial execution list (should execute nothing)
|
||||||
|
def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
_output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
|
||||||
|
# Run with empty partial execution list
|
||||||
|
try:
|
||||||
|
_result = client.run(g, partial_execution_targets=[])
|
||||||
|
# Should get an error because no outputs are selected
|
||||||
|
assert False, "Should have raised an error for empty partial execution list"
|
||||||
|
except urllib.error.HTTPError:
|
||||||
|
pass # Expected behavior
|
||||||
|
|
||||||
|
@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
|
|||||||
"expand": g.finalize(),
|
"expand": g.finalize(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TestOutputNodeWithSocketOutput:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def process(self, image, value):
|
||||||
|
# Apply value scaling and return both as output and socket
|
||||||
|
result = image * value
|
||||||
|
return (result,)
|
||||||
|
|
||||||
TEST_NODE_CLASS_MAPPINGS = {
|
TEST_NODE_CLASS_MAPPINGS = {
|
||||||
"TestLazyMixImages": TestLazyMixImages,
|
"TestLazyMixImages": TestLazyMixImages,
|
||||||
"TestVariadicAverage": TestVariadicAverage,
|
"TestVariadicAverage": TestVariadicAverage,
|
||||||
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
|||||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||||
"TestSleep": TestSleep,
|
"TestSleep": TestSleep,
|
||||||
"TestParallelSleep": TestParallelSleep,
|
"TestParallelSleep": TestParallelSleep,
|
||||||
|
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||||
"TestSleep": "Test Sleep",
|
"TestSleep": "Test Sleep",
|
||||||
"TestParallelSleep": "Test Parallel Sleep",
|
"TestParallelSleep": "Test Parallel Sleep",
|
||||||
|
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user