Add support for partial execution in backend (#9123)

When a prompt is submitted, it can optionally include
`partial_execution_targets` as a list of ids. If it does, rather than
adding all outputs to the execution list, we add only those in the list.
This commit is contained in:
guill
2025-07-30 19:55:28 -07:00
committed by GitHub
parent 61b08d4ba6
commit 97eb256a35
5 changed files with 233 additions and 19 deletions

View File

@@ -7,7 +7,7 @@ import threading
import time
import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
import torch
@@ -891,7 +891,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
async def validate_prompt(prompt_id, prompt):
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@@ -915,7 +915,8 @@ async def validate_prompt(prompt_id, prompt):
return (False, error, [], {})
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)
if len(outputs) == 0:
error = {

View File

@@ -681,7 +681,12 @@ class PromptServer():
if "prompt" in json_data:
prompt = json_data["prompt"]
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
valid = await execution.validate_prompt(prompt_id, prompt)
partial_execution_targets = None
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]

View File

@@ -7,7 +7,7 @@ import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient
from tests.inference.test_execution import ComfyClient, run_warmup
@pytest.mark.execution
@@ -24,6 +24,7 @@ class TestAsyncNodes:
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
@@ -82,6 +83,9 @@ class TestAsyncNodes:
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that multiple async nodes execute in parallel."""
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -148,6 +152,9 @@ class TestAsyncNodes:
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with lazy evaluation."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_lazy")
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -305,6 +312,9 @@ class TestAsyncNodes:
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async nodes are properly cached."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_cache")
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
@@ -324,6 +334,9 @@ class TestAsyncNodes:
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes within dynamically generated prompts."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_dynamic")
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

View File

@@ -15,10 +15,18 @@ import urllib.parse
import urllib.error
from comfy_execution.graph_utils import GraphBuilder, Node
def run_warmup(client, prefix="warmup"):
"""Run a simple workflow to warm up the server."""
warmup_g = GraphBuilder(prefix=prefix)
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
warmup_g.node("PreviewImage", images=warmup_image.out(0))
client.run(warmup_g)
class RunResult:
def __init__(self, prompt_id: str):
self.outputs: Dict[str,Dict] = {}
self.runs: Dict[str,bool] = {}
self.cached: Dict[str,bool] = {}
self.prompt_id: str = prompt_id
def get_output(self, node: Node):
@@ -27,6 +35,13 @@ class RunResult:
def did_run(self, node: Node):
return self.runs.get(node.id, False)
def was_cached(self, node: Node):
return self.cached.get(node.id, False)
def was_executed(self, node: Node):
"""Returns True if node was either run or cached"""
return self.did_run(node) or self.was_cached(node)
def get_images(self, node: Node):
output = self.get_output(node)
if output is None:
@@ -51,8 +66,10 @@ class ComfyClient:
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
self.ws = ws
def queue_prompt(self, prompt):
def queue_prompt(self, prompt, partial_execution_targets=None):
p = {"prompt": prompt, "client_id": self.client_id}
if partial_execution_targets is not None:
p["partial_execution_targets"] = partial_execution_targets
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
@@ -70,13 +87,13 @@ class ComfyClient:
def set_test_name(self, name):
self.test_name = name
def run(self, graph):
def run(self, graph, partial_execution_targets=None):
prompt = graph.finalize()
for node in graph.nodes.values():
if node.class_type == 'SaveImage':
node.inputs['filename_prefix'] = self.test_name
prompt_id = self.queue_prompt(prompt)['prompt_id']
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
result = RunResult(prompt_id)
while True:
out = self.ws.recv()
@@ -92,7 +109,10 @@ class ComfyClient:
elif message['type'] == 'execution_error':
raise Exception(message['data'])
elif message['type'] == 'execution_cached':
pass # Probably want to store this off for testing
if message['data']['prompt_id'] == prompt_id:
cached_nodes = message['data'].get('nodes', [])
for node_id in cached_nodes:
result.cached[node_id] = True
history = self.get_history(prompt_id)[prompt_id]
for node_id in history['outputs']:
@@ -130,6 +150,7 @@ class TestExecution:
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
@@ -498,12 +519,15 @@ class TestExecution:
assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create sleep nodes for each duration
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
# Add outputs to verify the execution
@@ -515,10 +539,9 @@ class TestExecution:
result = client.run(g)
elapsed_time = time.time() - start_time
# The test should take around 0.4 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
# We'll allow for up to 0.8s total to account for overhead
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
# The test should take around 3.0 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
@@ -526,6 +549,9 @@ class TestExecution:
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder
# Create input images with different values
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -537,9 +563,9 @@ class TestExecution:
image1=image1.out(0),
image2=image2.out(0),
image3=image3.out(0),
sleep1=0.4,
sleep2=0.5,
sleep3=0.6)
sleep1=4.8,
sleep2=4.9,
sleep3=5.0)
output = g.node("SaveImage", images=parallel_sleep.out(0))
start_time = time.time()
@@ -548,7 +574,7 @@ class TestExecution:
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
@@ -585,3 +611,151 @@ class TestExecution:
assert len(images) == 2, "Should have 2 images"
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
# Output nodes included in the partial execution list are executed
def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create two separate output nodes
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input2.out(0))
# Run with partial execution targeting only output1
result = client.run(g, partial_execution_targets=[output1.id])
assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
assert not result.did_run(input2), "Input2 should not have run"
assert not result.did_run(output2), "Output2 should not have run"
# Verify only output1 produced results
assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
# Output nodes NOT included in the partial execution list are NOT executed
def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
# Create three output nodes
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input2.out(0))
output3 = g.node("SaveImage", images=input3.out(0))
# Run with partial execution targeting only output1 and output3
result = client.run(g, partial_execution_targets=[output1.id, output3.id])
assert result.was_executed(input1), "Input1 should have been executed"
assert result.was_executed(input3), "Input3 should have been executed"
assert result.was_executed(output1), "Output1 should have been executed"
assert result.was_executed(output3), "Output3 should have been executed"
assert not result.did_run(input2), "Input2 should not have run"
assert not result.did_run(output2), "Output2 should not have run"
# Output nodes NOT in list ARE executed if necessary for nodes that are in the list
def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create a processing chain with an OUTPUT_NODE that has socket outputs
output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
# Create another node that depends on the output_with_socket
dependent_node = g.node("TestLazyMixImages",
image1=output_with_socket.out(0),
image2=input1.out(0),
mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
# Create the final output
final_output = g.node("SaveImage", images=dependent_node.out(0))
# Run with partial execution targeting only the final output
result = client.run(g, partial_execution_targets=[final_output.id])
# All nodes should have been executed because they're dependencies
assert result.was_executed(input1), "Input1 should have been executed"
assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
assert result.was_executed(dependent_node), "Dependent node should have been executed"
assert result.was_executed(final_output), "Final output should have been executed"
# Lazy execution works with partial execution
def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
# Create masks that will trigger different lazy execution paths
mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1
mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images
# Create two lazy mix nodes
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
output1 = g.node("SaveImage", images=lazy_mix1.out(0))
output2 = g.node("SaveImage", images=lazy_mix2.out(0))
# Run with partial execution targeting only output1
result = client.run(g, partial_execution_targets=[output1.id])
# For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
assert result.was_executed(input1), "Input1 should have been executed"
assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
assert result.was_executed(mask1), "Mask1 should have been executed"
assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
assert result.was_executed(output1), "Output1 should have been executed"
# Nothing from output2 path should run
assert not result.did_run(input3), "Input3 should not have run"
assert not result.did_run(mask2), "Mask2 should not have run"
assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
assert not result.did_run(output2), "Output2 should not have run"
# Multiple OUTPUT_NODEs with dependencies
def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create a chain of OUTPUT_NODEs
output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
# Create regular output nodes
save1 = g.node("SaveImage", images=output_node1.out(0))
save2 = g.node("SaveImage", images=output_node2.out(0))
save3 = g.node("SaveImage", images=input2.out(0))
# Run targeting only save2
result = client.run(g, partial_execution_targets=[save2.id])
# Should run: input1, output_node1, output_node2, save2
assert result.was_executed(input1), "Input1 should have been executed"
assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
assert result.was_executed(save2), "Save2 should have been executed"
# Should NOT run: input2, save1, save3
assert not result.did_run(input2), "Input2 should not have run"
assert not result.did_run(save1), "Save1 should not have run"
assert not result.did_run(save3), "Save3 should not have run"
# Empty partial execution list (should execute nothing)
def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
_output1 = g.node("SaveImage", images=input1.out(0))
# Run with empty partial execution list
try:
_result = client.run(g, partial_execution_targets=[])
# Should get an error because no outputs are selected
assert False, "Should have raised an error for empty partial execution list"
except urllib.error.HTTPError:
pass # Expected behavior

View File

@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
"expand": g.finalize(),
}
class TestOutputNodeWithSocketOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def process(self, image, value):
# Apply value scaling and return both as output and socket
result = image * value
return (result,)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestSamplingInExpansion": TestSamplingInExpansion,
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestSamplingInExpansion": "Sampling In Expansion",
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
}