mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Compare commits
96 Commits
v3-definit
...
67e906aa64
Author | SHA1 | Date | |
---|---|---|---|
|
67e906aa64 | ||
|
97b8a2c26a | ||
|
97eb256a35 | ||
|
61b08d4ba6 | ||
|
da9dab7edd | ||
|
d2aaef029c | ||
|
382f84a826 | ||
|
9cca36fa2b | ||
|
5d5024296d | ||
|
3b90a30178 | ||
|
3c4104652b | ||
|
9855baaab3 | ||
|
d53479a197 | ||
|
443a795850 | ||
|
431dec8e53 | ||
|
44e053c26d | ||
|
1ae98932f1 | ||
|
0336b0ace8 | ||
|
8ae25235ec | ||
|
9726eac475 | ||
|
272e8d42c1 | ||
|
6211d2be5a | ||
|
8be711715c | ||
|
b5cccf1325 | ||
|
2a54a904f4 | ||
|
ed6f92c975 | ||
|
adc66c0698 | ||
|
ccd5c01e5a | ||
|
2fa9affcc1 | ||
|
407a5a656f | ||
|
9ce9ff8ef8 | ||
|
63567c0ce8 | ||
|
a786ce5ead | ||
|
4879b47648 | ||
|
5ccec33c22 | ||
|
219d3cd0d0 | ||
|
c4ba399475 | ||
|
cc928a786d | ||
|
6e144b98c4 | ||
|
6dca17bd2d | ||
|
5080105c23 | ||
|
093914a247 | ||
|
605893d3cf | ||
|
048f4f0b3a | ||
|
d2504fb701 | ||
|
b03763bca6 | ||
|
476aa79b64 | ||
|
441cfd1a7a | ||
|
99a5c1068a | ||
|
02747cde7d | ||
|
0b3233b4e2 | ||
|
eda866bf51 | ||
|
e3298b84de | ||
|
c7feef9060 | ||
|
51af7fa1b4 | ||
|
46969c380a | ||
|
5db4277449 | ||
|
02a4d0ad7d | ||
|
ef137ac0b6 | ||
|
328d4f16a9 | ||
|
bdbcb85b8d | ||
|
6c9e94bae7 | ||
|
bfce723311 | ||
|
31f5458938 | ||
|
2145a202eb | ||
|
25818dc848 | ||
|
198953cd08 | ||
|
ec16ee2f39 | ||
|
d5088072fb | ||
|
8d4b50158e | ||
|
e88c6c03ff | ||
|
d3cf2b7b24 | ||
|
7448f02b7c | ||
|
871258aa72 | ||
|
66838ebd39 | ||
|
7333281698 | ||
|
3cd4c5cb0a | ||
|
11c6d56037 | ||
|
216fea15ee | ||
|
58bf8815c8 | ||
|
1b38f5bf57 | ||
|
2724ac4a60 | ||
|
f48f90e471 | ||
|
6463c39ce0 | ||
|
0a7e2ae787 | ||
|
03a97b604a | ||
|
4446c86052 | ||
|
8270ff312f | ||
|
db2d7ad9ba | ||
|
6620d86318 | ||
|
111fd0cadf | ||
|
776aa734e1 | ||
|
5a2ad032cb | ||
|
d44295ef71 | ||
|
bf21be066f | ||
|
72bbf49349 |
@@ -111,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
|
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
|
@@ -15,13 +15,14 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@@ -36,7 +37,7 @@ import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@@ -63,6 +64,18 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@@ -76,7 +89,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet = None
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@@ -84,6 +97,7 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@@ -110,17 +124,38 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@@ -129,7 +164,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
def copy_to(self, c: ControlBase):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@@ -280,6 +315,14 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@@ -806,6 +849,14 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
return x * torch.sigmoid(x)
|
||||
# x * sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
@@ -769,8 +769,7 @@ class CameraWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
x = x + self.control_adapter(camera_conditions).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
@@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@@ -26,6 +27,10 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@@ -182,6 +187,25 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
if is_nvidia():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
devices.remove(get_torch_device())
|
||||
return devices
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@@ -409,9 +433,13 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
@@ -422,7 +450,7 @@ def module_size(module):
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@@ -430,7 +458,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@@ -1364,8 +1392,34 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
@@ -87,12 +87,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@@ -243,6 +246,9 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@@ -321,18 +327,92 @@ class ModelPatcher:
|
||||
n.is_clip = self.is_clip
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@@ -935,7 +1015,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
all_models: list[ModelPatcher] = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@@ -989,9 +1069,13 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
callback(self, timestep, model_options, ignore_multigpu)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@@ -1004,12 +1088,18 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@@ -1021,6 +1111,28 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
|
167
comfy/multigpu.py
Normal file
167
comfy/multigpu.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
|
||||
# multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
# new_multigpu_models = []
|
||||
# for m in multigpu_models:
|
||||
# if m.load_device in limit_extra_devices:
|
||||
# new_multigpu_models.append(m)
|
||||
# model.set_additional_models("multigpu", new_multigpu_models)
|
||||
# persist skip_devices for use in sampling code
|
||||
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
||||
# model.model_options["multigpu_skip_devices"] = skip_devices
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
@@ -3,6 +3,8 @@ from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
@@ -106,6 +108,47 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@@ -130,7 +173,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
real_model: BaseModel = None
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@@ -149,7 +193,7 @@ def cleanup_models(conds, models):
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
@@ -182,3 +226,18 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
@@ -18,6 +20,7 @@ import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
import threading
|
||||
|
||||
|
||||
def add_area_dims(area, num_dims):
|
||||
@@ -140,7 +143,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
def cond_cat(c_list, device=None):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@@ -152,6 +155,8 @@ def cond_cat(c_list):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@@ -205,7 +210,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -237,7 +244,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@@ -345,6 +352,203 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
|
||||
total_conds = 0
|
||||
for to_run in hooked_to_run.values():
|
||||
total_conds += len(to_run)
|
||||
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
|
||||
index_device = 0
|
||||
current_device = devices[index_device]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
current_device = devices[index_device % len(devices)]
|
||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||
# keep track of conds currently scheduled onto this device
|
||||
batched_to_run_length = 0
|
||||
for btr in batched_to_run:
|
||||
batched_to_run_length += len(btr[1])
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
# make sure not over conds_per_device limit when creating temp batch
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
conds_to_batch = []
|
||||
for x in to_batch:
|
||||
conds_to_batch.append(to_run.pop(x))
|
||||
batched_to_run_length += len(conds_to_batch)
|
||||
|
||||
batched_to_run.append((hooks, conds_to_batch))
|
||||
if batched_to_run_length >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
output: Any
|
||||
mult: Any
|
||||
area: Any
|
||||
batch_chunks: int
|
||||
cond_or_uncond: Any
|
||||
error: Exception = None
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
raise
|
||||
|
||||
|
||||
results: list[thread_result] = []
|
||||
threads: list[threading.Thread] = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
|
||||
threads.append(new_thread)
|
||||
new_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||
if error is not None:
|
||||
raise error
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@@ -646,6 +850,8 @@ def pre_run_control(model, conds):
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device_cnet in x['control'].multigpu_clones.values():
|
||||
device_cnet.pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@@ -888,7 +1094,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@@ -897,18 +1105,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@@ -918,8 +1125,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@@ -927,7 +1134,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@@ -973,6 +1179,8 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
@@ -983,9 +1191,13 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
86
comfy_extras/nodes_multigpu.py
Normal file
86
comfy_extras/nodes_multigpu.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
|
||||
class MultiGPUWorkUnitsNode:
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
"""
|
||||
|
||||
NodeId = "MultiGPU_WorkUnits"
|
||||
NodeName = "MultiGPU Work Units"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"gpu_options": ("GPU_OPTIONS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "init_multigpu"
|
||||
CATEGORY = "advanced/multigpu"
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
|
||||
return (model,)
|
||||
|
||||
class MultiGPUOptionsNode:
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
"""
|
||||
|
||||
NodeId = "MultiGPU_Options"
|
||||
NodeName = "MultiGPU Options"
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
|
||||
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
|
||||
},
|
||||
"optional": {
|
||||
"gpu_options": ("GPU_OPTIONS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("GPU_OPTIONS",)
|
||||
FUNCTION = "create_gpu_options"
|
||||
CATEGORY = "advanced/multigpu"
|
||||
DESCRIPTION = cleandoc(__doc__)
|
||||
|
||||
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return (gpu_options,)
|
||||
|
||||
|
||||
node_list = [
|
||||
MultiGPUWorkUnitsNode,
|
||||
MultiGPUOptionsNode
|
||||
]
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
for node in node_list:
|
||||
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
||||
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
@@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
from typing import List, Literal, NamedTuple, Optional, Union
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
@@ -891,7 +891,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -915,7 +915,8 @@ async def validate_prompt(prompt_id, prompt):
|
||||
return (False, error, [], {})
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
outputs.add(x)
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
error = {
|
||||
|
1
nodes.py
1
nodes.py
@@ -2268,6 +2268,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_mahiro.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_multigpu.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
|
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.41
|
||||
comfyui-workflow-templates==0.1.44
|
||||
comfyui-embedded-docs==0.2.4
|
||||
torch
|
||||
torchsde
|
||||
|
@@ -681,7 +681,12 @@ class PromptServer():
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
|
||||
partial_execution_targets = None
|
||||
if "partial_execution_targets" in json_data:
|
||||
partial_execution_targets = json_data["partial_execution_targets"]
|
||||
|
||||
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
|
@@ -7,7 +7,7 @@ import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient
|
||||
from tests.inference.test_execution import ComfyClient, run_warmup
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
@@ -24,6 +24,7 @@ class TestAsyncNodes:
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
@@ -82,6 +83,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
@@ -148,6 +152,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with lazy evaluation."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_lazy")
|
||||
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@@ -305,6 +312,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async nodes are properly cached."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_cache")
|
||||
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||
@@ -324,6 +334,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes within dynamically generated prompts."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_dynamic")
|
||||
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
@@ -15,10 +15,18 @@ import urllib.parse
|
||||
import urllib.error
|
||||
from comfy_execution.graph_utils import GraphBuilder, Node
|
||||
|
||||
def run_warmup(client, prefix="warmup"):
|
||||
"""Run a simple workflow to warm up the server."""
|
||||
warmup_g = GraphBuilder(prefix=prefix)
|
||||
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
|
||||
warmup_g.node("PreviewImage", images=warmup_image.out(0))
|
||||
client.run(warmup_g)
|
||||
|
||||
class RunResult:
|
||||
def __init__(self, prompt_id: str):
|
||||
self.outputs: Dict[str,Dict] = {}
|
||||
self.runs: Dict[str,bool] = {}
|
||||
self.cached: Dict[str,bool] = {}
|
||||
self.prompt_id: str = prompt_id
|
||||
|
||||
def get_output(self, node: Node):
|
||||
@@ -27,6 +35,13 @@ class RunResult:
|
||||
def did_run(self, node: Node):
|
||||
return self.runs.get(node.id, False)
|
||||
|
||||
def was_cached(self, node: Node):
|
||||
return self.cached.get(node.id, False)
|
||||
|
||||
def was_executed(self, node: Node):
|
||||
"""Returns True if node was either run or cached"""
|
||||
return self.did_run(node) or self.was_cached(node)
|
||||
|
||||
def get_images(self, node: Node):
|
||||
output = self.get_output(node)
|
||||
if output is None:
|
||||
@@ -51,8 +66,10 @@ class ComfyClient:
|
||||
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
||||
self.ws = ws
|
||||
|
||||
def queue_prompt(self, prompt):
|
||||
def queue_prompt(self, prompt, partial_execution_targets=None):
|
||||
p = {"prompt": prompt, "client_id": self.client_id}
|
||||
if partial_execution_targets is not None:
|
||||
p["partial_execution_targets"] = partial_execution_targets
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
@@ -70,13 +87,13 @@ class ComfyClient:
|
||||
def set_test_name(self, name):
|
||||
self.test_name = name
|
||||
|
||||
def run(self, graph):
|
||||
def run(self, graph, partial_execution_targets=None):
|
||||
prompt = graph.finalize()
|
||||
for node in graph.nodes.values():
|
||||
if node.class_type == 'SaveImage':
|
||||
node.inputs['filename_prefix'] = self.test_name
|
||||
|
||||
prompt_id = self.queue_prompt(prompt)['prompt_id']
|
||||
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
|
||||
result = RunResult(prompt_id)
|
||||
while True:
|
||||
out = self.ws.recv()
|
||||
@@ -92,7 +109,10 @@ class ComfyClient:
|
||||
elif message['type'] == 'execution_error':
|
||||
raise Exception(message['data'])
|
||||
elif message['type'] == 'execution_cached':
|
||||
pass # Probably want to store this off for testing
|
||||
if message['data']['prompt_id'] == prompt_id:
|
||||
cached_nodes = message['data'].get('nodes', [])
|
||||
for node_id in cached_nodes:
|
||||
result.cached[node_id] = True
|
||||
|
||||
history = self.get_history(prompt_id)[prompt_id]
|
||||
for node_id in history['outputs']:
|
||||
@@ -130,6 +150,7 @@ class TestExecution:
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
@@ -498,12 +519,15 @@ class TestExecution:
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create sleep nodes for each duration
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
|
||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||
|
||||
# Add outputs to verify the execution
|
||||
@@ -515,10 +539,9 @@ class TestExecution:
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
||||
# We'll allow for up to 0.8s total to account for overhead
|
||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
||||
# The test should take around 3.0 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
|
||||
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
|
||||
|
||||
# Verify that all nodes executed
|
||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||
@@ -526,6 +549,9 @@ class TestExecution:
|
||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
g = builder
|
||||
# Create input images with different values
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
@@ -537,9 +563,9 @@ class TestExecution:
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
image3=image3.out(0),
|
||||
sleep1=0.4,
|
||||
sleep2=0.5,
|
||||
sleep3=0.6)
|
||||
sleep1=4.8,
|
||||
sleep2=4.9,
|
||||
sleep3=5.0)
|
||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
@@ -548,7 +574,7 @@ class TestExecution:
|
||||
|
||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||
# which should complete in less than the sum of all sleeps
|
||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
||||
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
|
||||
|
||||
# Verify the parallel sleep node executed
|
||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||
@@ -585,3 +611,151 @@ class TestExecution:
|
||||
assert len(images) == 2, "Should have 2 images"
|
||||
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||
|
||||
# Output nodes included in the partial execution list are executed
|
||||
def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create two separate output nodes
|
||||
output1 = g.node("SaveImage", images=input1.out(0))
|
||||
output2 = g.node("SaveImage", images=input2.out(0))
|
||||
|
||||
# Run with partial execution targeting only output1
|
||||
result = client.run(g, partial_execution_targets=[output1.id])
|
||||
|
||||
assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
|
||||
assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
|
||||
assert not result.did_run(input2), "Input2 should not have run"
|
||||
assert not result.did_run(output2), "Output2 should not have run"
|
||||
|
||||
# Verify only output1 produced results
|
||||
assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
|
||||
assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
|
||||
|
||||
# Output nodes NOT included in the partial execution list are NOT executed
|
||||
def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create three output nodes
|
||||
output1 = g.node("SaveImage", images=input1.out(0))
|
||||
output2 = g.node("SaveImage", images=input2.out(0))
|
||||
output3 = g.node("SaveImage", images=input3.out(0))
|
||||
|
||||
# Run with partial execution targeting only output1 and output3
|
||||
result = client.run(g, partial_execution_targets=[output1.id, output3.id])
|
||||
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert result.was_executed(input3), "Input3 should have been executed"
|
||||
assert result.was_executed(output1), "Output1 should have been executed"
|
||||
assert result.was_executed(output3), "Output3 should have been executed"
|
||||
assert not result.did_run(input2), "Input2 should not have run"
|
||||
assert not result.did_run(output2), "Output2 should not have run"
|
||||
|
||||
# Output nodes NOT in list ARE executed if necessary for nodes that are in the list
|
||||
def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a processing chain with an OUTPUT_NODE that has socket outputs
|
||||
output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
|
||||
|
||||
# Create another node that depends on the output_with_socket
|
||||
dependent_node = g.node("TestLazyMixImages",
|
||||
image1=output_with_socket.out(0),
|
||||
image2=input1.out(0),
|
||||
mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
|
||||
|
||||
# Create the final output
|
||||
final_output = g.node("SaveImage", images=dependent_node.out(0))
|
||||
|
||||
# Run with partial execution targeting only the final output
|
||||
result = client.run(g, partial_execution_targets=[final_output.id])
|
||||
|
||||
# All nodes should have been executed because they're dependencies
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
|
||||
assert result.was_executed(dependent_node), "Dependent node should have been executed"
|
||||
assert result.was_executed(final_output), "Final output should have been executed"
|
||||
|
||||
# Lazy execution works with partial execution
|
||||
def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create masks that will trigger different lazy execution paths
|
||||
mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1
|
||||
mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images
|
||||
|
||||
# Create two lazy mix nodes
|
||||
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
|
||||
lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
|
||||
|
||||
output1 = g.node("SaveImage", images=lazy_mix1.out(0))
|
||||
output2 = g.node("SaveImage", images=lazy_mix2.out(0))
|
||||
|
||||
# Run with partial execution targeting only output1
|
||||
result = client.run(g, partial_execution_targets=[output1.id])
|
||||
|
||||
# For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
|
||||
assert result.was_executed(mask1), "Mask1 should have been executed"
|
||||
assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
|
||||
assert result.was_executed(output1), "Output1 should have been executed"
|
||||
|
||||
# Nothing from output2 path should run
|
||||
assert not result.did_run(input3), "Input3 should not have run"
|
||||
assert not result.did_run(mask2), "Mask2 should not have run"
|
||||
assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
|
||||
assert not result.did_run(output2), "Output2 should not have run"
|
||||
|
||||
# Multiple OUTPUT_NODEs with dependencies
|
||||
def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a chain of OUTPUT_NODEs
|
||||
output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
|
||||
output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
|
||||
|
||||
# Create regular output nodes
|
||||
save1 = g.node("SaveImage", images=output_node1.out(0))
|
||||
save2 = g.node("SaveImage", images=output_node2.out(0))
|
||||
save3 = g.node("SaveImage", images=input2.out(0))
|
||||
|
||||
# Run targeting only save2
|
||||
result = client.run(g, partial_execution_targets=[save2.id])
|
||||
|
||||
# Should run: input1, output_node1, output_node2, save2
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
|
||||
assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
|
||||
assert result.was_executed(save2), "Save2 should have been executed"
|
||||
|
||||
# Should NOT run: input2, save1, save3
|
||||
assert not result.did_run(input2), "Input2 should not have run"
|
||||
assert not result.did_run(save1), "Save1 should not have run"
|
||||
assert not result.did_run(save3), "Save3 should not have run"
|
||||
|
||||
# Empty partial execution list (should execute nothing)
|
||||
def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
_output1 = g.node("SaveImage", images=input1.out(0))
|
||||
|
||||
# Run with empty partial execution list
|
||||
try:
|
||||
_result = client.run(g, partial_execution_targets=[])
|
||||
# Should get an error because no outputs are selected
|
||||
assert False, "Should have raised an error for empty partial execution list"
|
||||
except urllib.error.HTTPError:
|
||||
pass # Expected behavior
|
||||
|
||||
|
@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestOutputNodeWithSocketOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def process(self, image, value):
|
||||
# Apply value scaling and return both as output and socket
|
||||
result = image * value
|
||||
return (result,)
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
}
|
||||
|
Reference in New Issue
Block a user