From 97eb256a355b434bbc96ec27bbce33dd10273857 Mon Sep 17 00:00:00 2001 From: guill Date: Wed, 30 Jul 2025 19:55:28 -0700 Subject: [PATCH] Add support for partial execution in backend (#9123) When a prompt is submitted, it can optionally include `partial_execution_targets` as a list of ids. If it does, rather than adding all outputs to the execution list, we add only those in the list. --- execution.py | 7 +- server.py | 7 +- tests/inference/test_async_nodes.py | 15 +- tests/inference/test_execution.py | 202 ++++++++++++++++-- .../testing-pack/specific_tests.py | 21 ++ 5 files changed, 233 insertions(+), 19 deletions(-) diff --git a/execution.py b/execution.py index 8a9663a7d..cde14c52f 100644 --- a/execution.py +++ b/execution.py @@ -7,7 +7,7 @@ import threading import time import traceback from enum import Enum -from typing import List, Literal, NamedTuple, Optional +from typing import List, Literal, NamedTuple, Optional, Union import asyncio import torch @@ -891,7 +891,7 @@ def full_type_name(klass): return klass.__qualname__ return module + '.' + klass.__qualname__ -async def validate_prompt(prompt_id, prompt): +async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]): outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -915,7 +915,8 @@ async def validate_prompt(prompt_id, prompt): return (False, error, [], {}) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: - outputs.add(x) + if partial_execution_list is None or x in partial_execution_list: + outputs.add(x) if len(outputs) == 0: error = { diff --git a/server.py b/server.py index f4de0079b..3e06d2fbb 100644 --- a/server.py +++ b/server.py @@ -681,7 +681,12 @@ class PromptServer(): if "prompt" in json_data: prompt = json_data["prompt"] prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) - valid = await execution.validate_prompt(prompt_id, prompt) + + partial_execution_targets = None + if "partial_execution_targets" in json_data: + partial_execution_targets = json_data["partial_execution_targets"] + + valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) extra_data = {} if "extra_data" in json_data: extra_data = json_data["extra_data"] diff --git a/tests/inference/test_async_nodes.py b/tests/inference/test_async_nodes.py index b243bbca9..f029953dd 100644 --- a/tests/inference/test_async_nodes.py +++ b/tests/inference/test_async_nodes.py @@ -7,7 +7,7 @@ import subprocess from pytest import fixture from comfy_execution.graph_utils import GraphBuilder -from tests.inference.test_execution import ComfyClient +from tests.inference.test_execution import ComfyClient, run_warmup @pytest.mark.execution @@ -24,6 +24,7 @@ class TestAsyncNodes: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--cpu', ] use_lru, lru_size = request.param if use_lru: @@ -82,6 +83,9 @@ class TestAsyncNodes: def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): """Test that multiple async nodes execute in parallel.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client) + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -148,6 +152,9 @@ class TestAsyncNodes: def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes with lazy evaluation.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_lazy") + g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -305,6 +312,9 @@ class TestAsyncNodes: def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): """Test that async nodes are properly cached.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_cache") + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) @@ -324,6 +334,9 @@ class TestAsyncNodes: def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes within dynamically generated prompts.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_dynamic") + g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 9d3d685cc..e7b29302e 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -15,10 +15,18 @@ import urllib.parse import urllib.error from comfy_execution.graph_utils import GraphBuilder, Node +def run_warmup(client, prefix="warmup"): + """Run a simple workflow to warm up the server.""" + warmup_g = GraphBuilder(prefix=prefix) + warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + warmup_g.node("PreviewImage", images=warmup_image.out(0)) + client.run(warmup_g) + class RunResult: def __init__(self, prompt_id: str): self.outputs: Dict[str,Dict] = {} self.runs: Dict[str,bool] = {} + self.cached: Dict[str,bool] = {} self.prompt_id: str = prompt_id def get_output(self, node: Node): @@ -27,6 +35,13 @@ class RunResult: def did_run(self, node: Node): return self.runs.get(node.id, False) + def was_cached(self, node: Node): + return self.cached.get(node.id, False) + + def was_executed(self, node: Node): + """Returns True if node was either run or cached""" + return self.did_run(node) or self.was_cached(node) + def get_images(self, node: Node): output = self.get_output(node) if output is None: @@ -51,8 +66,10 @@ class ComfyClient: ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) self.ws = ws - def queue_prompt(self, prompt): + def queue_prompt(self, prompt, partial_execution_targets=None): p = {"prompt": prompt, "client_id": self.client_id} + if partial_execution_targets is not None: + p["partial_execution_targets"] = partial_execution_targets data = json.dumps(p).encode('utf-8') req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) return json.loads(urllib.request.urlopen(req).read()) @@ -70,13 +87,13 @@ class ComfyClient: def set_test_name(self, name): self.test_name = name - def run(self, graph): + def run(self, graph, partial_execution_targets=None): prompt = graph.finalize() for node in graph.nodes.values(): if node.class_type == 'SaveImage': node.inputs['filename_prefix'] = self.test_name - prompt_id = self.queue_prompt(prompt)['prompt_id'] + prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id'] result = RunResult(prompt_id) while True: out = self.ws.recv() @@ -92,7 +109,10 @@ class ComfyClient: elif message['type'] == 'execution_error': raise Exception(message['data']) elif message['type'] == 'execution_cached': - pass # Probably want to store this off for testing + if message['data']['prompt_id'] == prompt_id: + cached_nodes = message['data'].get('nodes', []) + for node_id in cached_nodes: + result.cached[node_id] = True history = self.get_history(prompt_id)[prompt_id] for node_id in history['outputs']: @@ -130,6 +150,7 @@ class TestExecution: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--cpu', ] use_lru, lru_size = request.param if use_lru: @@ -498,12 +519,15 @@ class TestExecution: assert not result.did_run(test_node), "The execution should have been cached" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder): + # Warmup execution to ensure server is fully initialized + run_warmup(client) + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) # Create sleep nodes for each duration - sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8) - sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1) sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0) # Add outputs to verify the execution @@ -515,10 +539,9 @@ class TestExecution: result = client.run(g) elapsed_time = time.time() - start_time - # The test should take around 0.4 seconds (the longest sleep duration) - # plus some overhead, but definitely less than the sum of all sleeps (0.9s) - # We'll allow for up to 0.8s total to account for overhead - assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s" + # The test should take around 3.0 seconds (the longest sleep duration) + # plus some overhead, but definitely less than the sum of all sleeps (9.0s) + assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" # Verify that all nodes executed assert result.did_run(sleep_node1), "Sleep node 1 should have run" @@ -526,6 +549,9 @@ class TestExecution: assert result.did_run(sleep_node3), "Sleep node 3 should have run" def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder): + # Warmup execution to ensure server is fully initialized + run_warmup(client) + g = builder # Create input images with different values image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -537,9 +563,9 @@ class TestExecution: image1=image1.out(0), image2=image2.out(0), image3=image3.out(0), - sleep1=0.4, - sleep2=0.5, - sleep3=0.6) + sleep1=4.8, + sleep2=4.9, + sleep3=5.0) output = g.node("SaveImage", images=parallel_sleep.out(0)) start_time = time.time() @@ -548,7 +574,7 @@ class TestExecution: # Similar to the previous test, expect parallel execution of the sleep nodes # which should complete in less than the sum of all sleeps - assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s" + assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s" # Verify the parallel sleep node executed assert result.did_run(parallel_sleep), "ParallelSleep node should have run" @@ -585,3 +611,151 @@ class TestExecution: assert len(images) == 2, "Should have 2 images" assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black" assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black" + + # Output nodes included in the partial execution list are executed + def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create two separate output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + + # Run with partial execution targeting only output1 + result = client.run(g, partial_execution_targets=[output1.id]) + + assert result.was_executed(input1), "Input1 should have been executed (run or cached)" + assert result.was_executed(output1), "Output1 should have been executed (run or cached)" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Verify only output1 produced results + assert len(result.get_images(output1)) == 1, "Output1 should have produced an image" + assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image" + + # Output nodes NOT included in the partial execution list are NOT executed + def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create three output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + output3 = g.node("SaveImage", images=input3.out(0)) + + # Run with partial execution targeting only output1 and output3 + result = client.run(g, partial_execution_targets=[output1.id, output3.id]) + + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(input3), "Input3 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + assert result.was_executed(output3), "Output3 should have been executed" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Output nodes NOT in list ARE executed if necessary for nodes that are in the list + def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create a processing chain with an OUTPUT_NODE that has socket outputs + output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0) + + # Create another node that depends on the output_with_socket + dependent_node = g.node("TestLazyMixImages", + image1=output_with_socket.out(0), + image2=input1.out(0), + mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0)) + + # Create the final output + final_output = g.node("SaveImage", images=dependent_node.out(0)) + + # Run with partial execution targeting only the final output + result = client.run(g, partial_execution_targets=[final_output.id]) + + # All nodes should have been executed because they're dependencies + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)" + assert result.was_executed(dependent_node), "Dependent node should have been executed" + assert result.was_executed(final_output), "Final output should have been executed" + + # Lazy execution works with partial execution + def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create masks that will trigger different lazy execution paths + mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1 + mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images + + # Create two lazy mix nodes + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0)) + + output1 = g.node("SaveImage", images=lazy_mix1.out(0)) + output2 = g.node("SaveImage", images=lazy_mix2.out(0)) + + # Run with partial execution targeting only output1 + result = client.run(g, partial_execution_targets=[output1.id]) + + # For output1 path - only input1 should run due to lazy evaluation (mask=0.0) + assert result.was_executed(input1), "Input1 should have been executed" + assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)" + assert result.was_executed(mask1), "Mask1 should have been executed" + assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + + # Nothing from output2 path should run + assert not result.did_run(input3), "Input3 should not have run" + assert not result.did_run(mask2), "Mask2 should not have run" + assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Multiple OUTPUT_NODEs with dependencies + def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create a chain of OUTPUT_NODEs + output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5) + output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0) + + # Create regular output nodes + save1 = g.node("SaveImage", images=output_node1.out(0)) + save2 = g.node("SaveImage", images=output_node2.out(0)) + save3 = g.node("SaveImage", images=input2.out(0)) + + # Run targeting only save2 + result = client.run(g, partial_execution_targets=[save2.id]) + + # Should run: input1, output_node1, output_node2, save2 + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)" + assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)" + assert result.was_executed(save2), "Save2 should have been executed" + + # Should NOT run: input2, save1, save3 + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(save1), "Save1 should not have run" + assert not result.did_run(save3), "Save3 should not have run" + + # Empty partial execution list (should execute nothing) + def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + _output1 = g.node("SaveImage", images=input1.out(0)) + + # Run with empty partial execution list + try: + _result = client.run(g, partial_execution_targets=[]) + # Should get an error because no outputs are selected + assert False, "Should have raised an error for empty partial execution list" + except urllib.error.HTTPError: + pass # Expected behavior + diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 657d49f2f..4f8f01ae4 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC): "expand": g.finalize(), } +class TestOutputNodeWithSocketOutput: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), + }, + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing" + OUTPUT_NODE = True + + def process(self, image, value): + # Apply value scaling and return both as output and socket + result = image * value + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestSamplingInExpansion": TestSamplingInExpansion, "TestSleep": TestSleep, "TestParallelSleep": TestParallelSleep, + "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestSamplingInExpansion": "Sampling In Expansion", "TestSleep": "Test Sleep", "TestParallelSleep": "Test Parallel Sleep", + "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output", }