
Add support for partial execution in backend (#9123)

When a prompt is submitted, it can optionally include
`partial_execution_targets`, a list of node ids. If it does, then rather than
adding all output nodes to the execution list, we add only those whose ids appear in the list.
guill
2025-07-30 19:55:28 -07:00
committed by GitHub
parent 61b08d4ba6
commit 97eb256a35
5 changed files with 233 additions and 19 deletions
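For illustration, queueing a prompt with the new field could look like the following sketch. The two-node prompt, the node ids, and the server address are invented for the example (`StubImage` is a test-only node, and a real prompt carries fully wired "inputs" for every node), but the payload keys match the test client changes in this commit:

import json
import urllib.request

# Hypothetical workflow with one source node and two output nodes.
prompt = {
    "1": {"class_type": "StubImage",
          "inputs": {"content": "BLACK", "height": 512, "width": 512, "batch_size": 1}},
    "2": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
    "3": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
}

payload = {
    "prompt": prompt,
    # Only output node "2" is scheduled; output node "3" is skipped entirely.
    "partial_execution_targets": ["2"],
}

req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",  # assumed default local server address
    data=json.dumps(payload).encode("utf-8"),
)
print(json.loads(urllib.request.urlopen(req).read()))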

execution.py

@@ -7,7 +7,7 @@ import threading
 import time
 import traceback
 from enum import Enum
-from typing import List, Literal, NamedTuple, Optional
+from typing import List, Literal, NamedTuple, Optional, Union
 
 import asyncio
 import torch
@@ -891,7 +891,7 @@ def full_type_name(klass):
         return klass.__qualname__
     return module + '.' + klass.__qualname__
 
-async def validate_prompt(prompt_id, prompt):
+async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
     outputs = set()
     for x in prompt:
         if 'class_type' not in prompt[x]:
@@ -915,6 +915,7 @@ async def validate_prompt(prompt_id, prompt):
             return (False, error, [], {})
 
         if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
-            outputs.add(x)
+            if partial_execution_list is None or x in partial_execution_list:
+                outputs.add(x)
 
     if len(outputs) == 0:
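The selection rule the new branch implements is small enough to state on its own. A minimal, runnable sketch, with the node-class lookup stubbed out as a plain flag (the real validator resolves class_.OUTPUT_NODE and performs much more validation):

from typing import Union

def select_outputs(prompt: dict, partial_execution_list: Union[list, None]) -> set:
    """Collect the ids of output nodes that should seed the execution list."""
    outputs = set()
    for x in prompt:
        is_output = prompt[x].get("is_output_node", False)  # stand-in for class_.OUTPUT_NODE
        if is_output and (partial_execution_list is None or x in partial_execution_list):
            outputs.add(x)
    return outputs

# None means "no filtering", preserving the old run-every-output behavior;
# an empty list selects nothing, which validate_prompt then rejects.
assert select_outputs({"a": {"is_output_node": True}, "b": {}}, None) == {"a"}
assert select_outputs({"a": {"is_output_node": True}, "b": {}}, []) == set()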

server.py

@@ -681,7 +681,12 @@ class PromptServer():
         if "prompt" in json_data:
             prompt = json_data["prompt"]
             prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
-            valid = await execution.validate_prompt(prompt_id, prompt)
+
+            partial_execution_targets = None
+            if "partial_execution_targets" in json_data:
+                partial_execution_targets = json_data["partial_execution_targets"]
+
+            valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
             extra_data = {}
             if "extra_data" in json_data:
                 extra_data = json_data["extra_data"]

tests/inference/test_async_nodes.py

@@ -7,7 +7,7 @@ import subprocess
 
 from pytest import fixture
 
 from comfy_execution.graph_utils import GraphBuilder
-from tests.inference.test_execution import ComfyClient
+from tests.inference.test_execution import ComfyClient, run_warmup
 
 @pytest.mark.execution
@@ -24,6 +24,7 @@ class TestAsyncNodes:
             '--listen', args_pytest["listen"],
             '--port', str(args_pytest["port"]),
             '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
+            '--cpu',
         ]
         use_lru, lru_size = request.param
         if use_lru:
@@ -82,6 +83,9 @@ class TestAsyncNodes:
     def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
         """Test that multiple async nodes execute in parallel."""
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client)
+
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -148,6 +152,9 @@ class TestAsyncNodes:
     def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
         """Test async nodes with lazy evaluation."""
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client, prefix="warmup_lazy")
+
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -305,6 +312,9 @@ class TestAsyncNodes:
     def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
         """Test that async nodes are properly cached."""
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client, prefix="warmup_cache")
+
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
@@ -324,6 +334,9 @@ class TestAsyncNodes:
     def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
         """Test async nodes within dynamically generated prompts."""
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client, prefix="warmup_dynamic")
+
         g = builder
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

tests/inference/test_execution.py

@@ -15,10 +15,18 @@ import urllib.parse
 import urllib.error
 
 from comfy_execution.graph_utils import GraphBuilder, Node
 
+def run_warmup(client, prefix="warmup"):
+    """Run a simple workflow to warm up the server."""
+    warmup_g = GraphBuilder(prefix=prefix)
+    warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
+    warmup_g.node("PreviewImage", images=warmup_image.out(0))
+    client.run(warmup_g)
+
 class RunResult:
     def __init__(self, prompt_id: str):
         self.outputs: Dict[str,Dict] = {}
         self.runs: Dict[str,bool] = {}
+        self.cached: Dict[str,bool] = {}
         self.prompt_id: str = prompt_id
 
     def get_output(self, node: Node):
@@ -27,6 +35,13 @@ class RunResult:
     def did_run(self, node: Node):
         return self.runs.get(node.id, False)
 
+    def was_cached(self, node: Node):
+        return self.cached.get(node.id, False)
+
+    def was_executed(self, node: Node):
+        """Returns True if node was either run or cached"""
+        return self.did_run(node) or self.was_cached(node)
+
     def get_images(self, node: Node):
         output = self.get_output(node)
         if output is None:
@@ -51,8 +66,10 @@ class ComfyClient:
         ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
         self.ws = ws
 
-    def queue_prompt(self, prompt):
+    def queue_prompt(self, prompt, partial_execution_targets=None):
         p = {"prompt": prompt, "client_id": self.client_id}
+        if partial_execution_targets is not None:
+            p["partial_execution_targets"] = partial_execution_targets
         data = json.dumps(p).encode('utf-8')
         req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
         return json.loads(urllib.request.urlopen(req).read())
@@ -70,13 +87,13 @@ class ComfyClient:
     def set_test_name(self, name):
         self.test_name = name
 
-    def run(self, graph):
+    def run(self, graph, partial_execution_targets=None):
         prompt = graph.finalize()
         for node in graph.nodes.values():
             if node.class_type == 'SaveImage':
                 node.inputs['filename_prefix'] = self.test_name
 
-        prompt_id = self.queue_prompt(prompt)['prompt_id']
+        prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
         result = RunResult(prompt_id)
         while True:
             out = self.ws.recv()
@@ -92,7 +109,10 @@ class ComfyClient:
             elif message['type'] == 'execution_error':
                 raise Exception(message['data'])
             elif message['type'] == 'execution_cached':
-                pass # Probably want to store this off for testing
+                if message['data']['prompt_id'] == prompt_id:
+                    cached_nodes = message['data'].get('nodes', [])
+                    for node_id in cached_nodes:
+                        result.cached[node_id] = True
 
         history = self.get_history(prompt_id)[prompt_id]
         for node_id in history['outputs']:
@@ -130,6 +150,7 @@ class TestExecution:
             '--listen', args_pytest["listen"],
             '--port', str(args_pytest["port"]),
             '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
+            '--cpu',
         ]
         use_lru, lru_size = request.param
         if use_lru:
@@ -498,12 +519,15 @@ class TestExecution:
         assert not result.did_run(test_node), "The execution should have been cached"
 
     def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client)
+
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
 
         # Create sleep nodes for each duration
-        sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
-        sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
+        sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
+        sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
         sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
 
         # Add outputs to verify the execution
@@ -515,10 +539,9 @@ class TestExecution:
         result = client.run(g)
         elapsed_time = time.time() - start_time
 
-        # The test should take around 0.4 seconds (the longest sleep duration)
-        # plus some overhead, but definitely less than the sum of all sleeps (0.9s)
-        # We'll allow for up to 0.8s total to account for overhead
-        assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
+        # The test should take around 3.0 seconds (the longest sleep duration)
+        # plus some overhead, but definitely less than the sum of all sleeps (9.0s)
+        assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
 
         # Verify that all nodes executed
         assert result.did_run(sleep_node1), "Sleep node 1 should have run"
@@ -526,6 +549,9 @@ class TestExecution:
         assert result.did_run(sleep_node3), "Sleep node 3 should have run"
 
     def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
+        # Warmup execution to ensure server is fully initialized
+        run_warmup(client)
+
         g = builder
         # Create input images with different values
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -537,9 +563,9 @@ class TestExecution:
             image1=image1.out(0),
             image2=image2.out(0),
             image3=image3.out(0),
-            sleep1=0.4,
-            sleep2=0.5,
-            sleep3=0.6)
+            sleep1=4.8,
+            sleep2=4.9,
+            sleep3=5.0)
         output = g.node("SaveImage", images=parallel_sleep.out(0))
 
         start_time = time.time()
@@ -548,7 +574,7 @@ class TestExecution:
 
         # Similar to the previous test, expect parallel execution of the sleep nodes
         # which should complete in less than the sum of all sleeps
-        assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
+        assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 10.0s"
 
         # Verify the parallel sleep node executed
         assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
@@ -585,3 +611,151 @@ class TestExecution:
         assert len(images) == 2, "Should have 2 images"
         assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
         assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
+
+    # Output nodes included in the partial execution list are executed
+    def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+
+        # Create two separate output nodes
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input2.out(0))
+
+        # Run with partial execution targeting only output1
+        result = client.run(g, partial_execution_targets=[output1.id])
+
+        assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
+        assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
+        assert not result.did_run(input2), "Input2 should not have run"
+        assert not result.did_run(output2), "Output2 should not have run"
+
+        # Verify only output1 produced results
+        assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
+        assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
+
+    # Output nodes NOT included in the partial execution list are NOT executed
+    def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
+
+        # Create three output nodes
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input2.out(0))
+        output3 = g.node("SaveImage", images=input3.out(0))
+
+        # Run with partial execution targeting only output1 and output3
+        result = client.run(g, partial_execution_targets=[output1.id, output3.id])
+
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert result.was_executed(input3), "Input3 should have been executed"
+        assert result.was_executed(output1), "Output1 should have been executed"
+        assert result.was_executed(output3), "Output3 should have been executed"
+        assert not result.did_run(input2), "Input2 should not have run"
+        assert not result.did_run(output2), "Output2 should not have run"
+
+    # Output nodes NOT in the list ARE executed if they are needed by nodes that are in the list
+    def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+
+        # Create a processing chain with an OUTPUT_NODE that has socket outputs
+        output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
+
+        # Create another node that depends on the output_with_socket
+        dependent_node = g.node("TestLazyMixImages",
+                                image1=output_with_socket.out(0),
+                                image2=input1.out(0),
+                                mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
+
+        # Create the final output
+        final_output = g.node("SaveImage", images=dependent_node.out(0))
+
+        # Run with partial execution targeting only the final output
+        result = client.run(g, partial_execution_targets=[final_output.id])
+
+        # All nodes should have been executed because they're dependencies
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
+        assert result.was_executed(dependent_node), "Dependent node should have been executed"
+        assert result.was_executed(final_output), "Final output should have been executed"
+
+    # Lazy execution works with partial execution
+    def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
+
+        # Create masks that will trigger different lazy execution paths
+        mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)  # Will only need image1
+        mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)  # Will need both images
+
+        # Create two lazy mix nodes
+        lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
+        lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
+
+        output1 = g.node("SaveImage", images=lazy_mix1.out(0))
+        output2 = g.node("SaveImage", images=lazy_mix2.out(0))
+
+        # Run with partial execution targeting only output1
+        result = client.run(g, partial_execution_targets=[output1.id])
+
+        # For the output1 path, only input1 should run due to lazy evaluation (mask=0.0)
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
+        assert result.was_executed(mask1), "Mask1 should have been executed"
+        assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
+        assert result.was_executed(output1), "Output1 should have been executed"
+
+        # Nothing from the output2 path should run
+        assert not result.did_run(input3), "Input3 should not have run"
+        assert not result.did_run(mask2), "Mask2 should not have run"
+        assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
+        assert not result.did_run(output2), "Output2 should not have run"
+
+    # Multiple OUTPUT_NODEs with dependencies between them
+    def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+
+        # Create a chain of OUTPUT_NODEs
+        output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
+        output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
+
+        # Create regular output nodes
+        save1 = g.node("SaveImage", images=output_node1.out(0))
+        save2 = g.node("SaveImage", images=output_node2.out(0))
+        save3 = g.node("SaveImage", images=input2.out(0))
+
+        # Run targeting only save2
+        result = client.run(g, partial_execution_targets=[save2.id])
+
+        # Should run: input1, output_node1, output_node2, save2
+        assert result.was_executed(input1), "Input1 should have been executed"
+        assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
+        assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
+        assert result.was_executed(save2), "Save2 should have been executed"
+
+        # Should NOT run: input2, save1, save3
+        assert not result.did_run(input2), "Input2 should not have run"
+        assert not result.did_run(save1), "Save1 should not have run"
+        assert not result.did_run(save3), "Save3 should not have run"
+
+    # An empty partial execution list should execute nothing
+    def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        _output1 = g.node("SaveImage", images=input1.out(0))
+
+        # Run with an empty partial execution list
+        try:
+            _result = client.run(g, partial_execution_targets=[])
+            # Should get an error because no outputs are selected
+            assert False, "Should have raised an error for empty partial execution list"
+        except urllib.error.HTTPError:
+            pass  # Expected behavior
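The empty-list case surfaces to plain HTTP clients as a failed request. A standalone sketch of that behavior outside the test harness (the server address is the assumed local default, and the stub/preview nodes are illustrative):

import json
import urllib.error
import urllib.request

prompt = {
    "1": {"class_type": "StubImage",
          "inputs": {"content": "BLACK", "height": 32, "width": 32, "batch_size": 1}},
    "2": {"class_type": "PreviewImage", "inputs": {"images": ["1", 0]}},
}

req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",  # assumed default local server address
    data=json.dumps({"prompt": prompt, "partial_execution_targets": []}).encode("utf-8"),
)
try:
    urllib.request.urlopen(req)
except urllib.error.HTTPError as e:
    # With no outputs selected, validate_prompt fails and the server
    # answers with an error status instead of queueing the prompt.
    print("rejected as expected:", e.code)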

tests/inference/testing_nodes/testing-pack/specific_tests.py

@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
             "expand": g.finalize(),
         }
 
+class TestOutputNodeWithSocketOutput:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "process"
+    CATEGORY = "_for_testing"
+    OUTPUT_NODE = True
+
+    def process(self, image, value):
+        # Apply value scaling and return both as an output and through the socket
+        result = image * value
+        return (result,)
+
 TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestSamplingInExpansion": TestSamplingInExpansion,
     "TestSleep": TestSleep,
     "TestParallelSleep": TestParallelSleep,
+    "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
 }
 
 TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestSamplingInExpansion": "Sampling In Expansion",
     "TestSleep": "Test Sleep",
     "TestParallelSleep": "Test Parallel Sleep",
+    "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
 }