mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Compare commits
10 Commits
8beead753a
...
ab98b65226
Author | SHA1 | Date | |
---|---|---|---|
|
ab98b65226 | ||
|
b99e3d1336 | ||
|
3aceeab359 | ||
|
326a2593e0 | ||
|
a8f1981bf2 | ||
|
5c94199b04 | ||
|
205611cc22 | ||
|
d703ba9633 | ||
|
106bc9b32a | ||
|
c3334ae813 |
@@ -1,6 +1,10 @@
|
||||
class ComfyNodeInternal:
|
||||
'''Class that all V3-based APIs inhertif from for ComfyNode.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.'''
|
||||
...
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class ComfyNodeInternal(ABC):
|
||||
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def GET_NODE_INFO_V1(cls):
|
||||
...
|
||||
|
@@ -1,10 +1,16 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def first_real_override(cls: type, name: str, *, base: type) -> Optional[Callable]:
|
||||
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||
|
||||
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||
"""
|
||||
if base is None:
|
||||
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||
base = cls.GET_BASE_CLASS()
|
||||
base_attr = getattr(base, name, None)
|
||||
if base_attr is None:
|
||||
return None
|
||||
|
@@ -6,11 +6,12 @@ from collections import Counter
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, TypedDict, TypeVar
|
||||
from comfy_api.v3.helpers import first_real_override
|
||||
|
||||
# used for type hinting
|
||||
import torch
|
||||
from spandrel import ImageModelDescriptor
|
||||
from typing_extensions import NotRequired
|
||||
from typing_extensions import NotRequired, final
|
||||
|
||||
from comfy.clip_vision import ClipVisionModel
|
||||
from comfy.clip_vision import Output as ClipVisionOutput_
|
||||
@@ -101,6 +102,7 @@ def copy_class(cls: type) -> type:
|
||||
class NumberDisplay(str, Enum):
|
||||
number = "number"
|
||||
slider = "slider"
|
||||
color = "color"
|
||||
|
||||
|
||||
class ComfyType(ABC):
|
||||
@@ -188,14 +190,15 @@ class InputV3(IO_V3):
|
||||
self.lazy = lazy
|
||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||
|
||||
def as_dict_V1(self):
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
"display_name": self.display_name,
|
||||
"optional": self.optional,
|
||||
"tooltip": self.tooltip,
|
||||
"lazy": self.lazy,
|
||||
}) | prune_dict(self.extra_dict)
|
||||
|
||||
def get_io_type_V1(self):
|
||||
def get_io_type(self):
|
||||
return self.io_type
|
||||
|
||||
class WidgetInputV3(InputV3):
|
||||
@@ -204,23 +207,23 @@ class WidgetInputV3(InputV3):
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: Any=None,
|
||||
socketless: bool=None, widgetType: str=None, force_input: bool=None, extra_dict=None):
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.default = default
|
||||
self.socketless = socketless
|
||||
self.widgetType = widgetType
|
||||
self.widget_type = widget_type
|
||||
self.force_input = force_input
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"default": self.default,
|
||||
"socketless": self.socketless,
|
||||
"widgetType": self.widgetType,
|
||||
"widgetType": self.widget_type,
|
||||
"forceInput": self.force_input,
|
||||
})
|
||||
|
||||
def get_io_type_V1(self):
|
||||
return self.widgetType if self.widgetType is not None else super().get_io_type_V1()
|
||||
def get_io_type(self):
|
||||
return self.widget_type if self.widget_type is not None else super().get_io_type()
|
||||
|
||||
|
||||
class OutputV3(IO_V3):
|
||||
@@ -230,14 +233,17 @@ class OutputV3(IO_V3):
|
||||
self.display_name = display_name
|
||||
self.tooltip = tooltip
|
||||
self.is_output_list = is_output_list
|
||||
|
||||
def as_dict_V3(self):
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
"display_name": self.display_name,
|
||||
"tooltip": self.tooltip,
|
||||
"is_output_list": self.is_output_list,
|
||||
})
|
||||
|
||||
def get_io_type(self):
|
||||
return self.io_type
|
||||
|
||||
|
||||
class ComfyTypeI(ComfyType):
|
||||
'''ComfyType subclass that only has a default Input class - intended for types that only have Inputs.'''
|
||||
@@ -325,8 +331,8 @@ class Boolean(ComfyTypeIO):
|
||||
self.label_off = label_off
|
||||
self.default: bool
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"label_on": self.label_on,
|
||||
"label_off": self.label_off,
|
||||
})
|
||||
@@ -348,13 +354,13 @@ class Int(ComfyTypeIO):
|
||||
self.display_mode = display_mode
|
||||
self.default: int
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"step": self.step,
|
||||
"control_after_generate": self.control_after_generate,
|
||||
"display": self.display_mode,
|
||||
"display": self.display_mode.value if self.display_mode else None,
|
||||
})
|
||||
|
||||
@comfytype(io_type="FLOAT")
|
||||
@@ -374,8 +380,8 @@ class Float(ComfyTypeIO):
|
||||
self.display_mode = display_mode
|
||||
self.default: float
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"step": self.step,
|
||||
@@ -390,19 +396,19 @@ class String(ComfyTypeIO):
|
||||
class Input(WidgetInputV3):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamicPrompts: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamicPrompts = dynamicPrompts
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
self.default: str
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"multiline": self.multiline,
|
||||
"placeholder": self.placeholder,
|
||||
"dynamicPrompts": self.dynamicPrompts,
|
||||
"dynamicPrompts": self.dynamic_prompts,
|
||||
})
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
@@ -425,8 +431,8 @@ class Combo(ComfyTypeI):
|
||||
self.remote = remote
|
||||
self.default: str
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"multiselect": self.multiselect,
|
||||
"options": self.options,
|
||||
"control_after_generate": self.control_after_generate,
|
||||
@@ -444,15 +450,15 @@ class MultiCombo(ComfyTypeI):
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
socketless: bool=None, widgetType: str=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless, widgetType)
|
||||
socketless: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
||||
self.multiselect = True
|
||||
self.placeholder = placeholder
|
||||
self.chip = chip
|
||||
self.default: list[str]
|
||||
|
||||
def as_dict_V1(self):
|
||||
to_return = super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
to_return = super().as_dict() | prune_dict({
|
||||
"multi_select": self.multiselect,
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
@@ -767,9 +773,9 @@ class MultiType:
|
||||
display_name = id.display_name if id.display_name is not None else display_name
|
||||
lazy = id.lazy if id.lazy is not None else lazy
|
||||
id = id.id
|
||||
# if is a widget input, make sure widgetType is set appropriately
|
||||
# if is a widget input, make sure widget_type is set appropriately
|
||||
if isinstance(self.input_override, WidgetInputV3):
|
||||
self.input_override.widgetType = self.input_override.get_io_type_V1()
|
||||
self.input_override.widget_type = self.input_override.get_io_type()
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self._io_types = types
|
||||
|
||||
@@ -786,18 +792,18 @@ class MultiType:
|
||||
io_types.append(x)
|
||||
return io_types
|
||||
|
||||
def get_io_type_V1(self):
|
||||
def get_io_type(self):
|
||||
# ensure types are unique and order is preserved
|
||||
str_types = [x.io_type for x in self.io_types]
|
||||
if self.input_override is not None:
|
||||
str_types.insert(0, self.input_override.get_io_type_V1())
|
||||
str_types.insert(0, self.input_override.get_io_type())
|
||||
return ",".join(list(dict.fromkeys(str_types)))
|
||||
|
||||
def as_dict_V1(self):
|
||||
def as_dict(self):
|
||||
if self.input_override is not None:
|
||||
return self.input_override.as_dict_V1() | super().as_dict_V1()
|
||||
return self.input_override.as_dict() | super().as_dict()
|
||||
else:
|
||||
return super().as_dict_V1()
|
||||
return super().as_dict()
|
||||
|
||||
class DynamicInput(InputV3, ABC):
|
||||
'''
|
||||
@@ -889,22 +895,22 @@ class MatchType(ComfyTypeIO):
|
||||
def get_dynamic(self) -> list[InputV3]:
|
||||
return [self]
|
||||
|
||||
def as_dict_V1(self):
|
||||
return super().as_dict_V1() | prune_dict({
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
|
||||
class Output(DynamicOutput):
|
||||
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
|
||||
|
||||
def get_dynamic(self) -> list[OutputV3]:
|
||||
return [self]
|
||||
|
||||
def as_dict_V3(self):
|
||||
return super().as_dict_V3() | prune_dict({
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
@@ -979,6 +985,19 @@ class NodeInfoV1:
|
||||
experimental: bool=None
|
||||
api_node: bool=None
|
||||
|
||||
@dataclass
|
||||
class NodeInfoV3:
|
||||
input: dict=None
|
||||
output: dict=None
|
||||
hidden: list[str]=None
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
category: str=None
|
||||
output_node: bool=None
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
api_node: bool=None
|
||||
|
||||
def as_pruned_dict(dataclass_obj):
|
||||
'''Return dict of dataclass object with pruned None values.'''
|
||||
@@ -1081,6 +1100,84 @@ class SchemaV3:
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
|
||||
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||
# get V1 inputs
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
if self.inputs:
|
||||
for i in self.inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
dynamic_inputs = i.get_dynamic()
|
||||
for d in dynamic_inputs:
|
||||
add_to_dict_v1(d, input)
|
||||
else:
|
||||
add_to_dict_v1(i, input)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||
# create separate lists from output fields
|
||||
output = []
|
||||
output_is_list = []
|
||||
output_name = []
|
||||
output_tooltips = []
|
||||
if self.outputs:
|
||||
for o in self.outputs:
|
||||
output.append(o.io_type)
|
||||
output_is_list.append(o.is_output_list)
|
||||
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
input_order={key: list(value.keys()) for (key, value) in input.items()},
|
||||
output=output,
|
||||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
output_tooltips=output_tooltips,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
category=self.category,
|
||||
description=self.description,
|
||||
output_node=self.is_output_node,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def get_v3_info(self, cls) -> NodeInfoV3:
|
||||
input_dict = {}
|
||||
output_dict = {}
|
||||
hidden_list = []
|
||||
# TODO: make sure dynamic types will be handled correctly
|
||||
if self.inputs:
|
||||
for input in self.inputs:
|
||||
add_to_dict_v3(input, input_dict)
|
||||
if self.outputs:
|
||||
for output in self.outputs:
|
||||
add_to_dict_v3(output, output_dict)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
hidden_list.append(hidden.value)
|
||||
|
||||
info = NodeInfoV3(
|
||||
input=input_dict,
|
||||
output=output_dict,
|
||||
hidden=hidden_list,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
output_node=self.is_output_node,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
)
|
||||
return info
|
||||
|
||||
class Serializer:
|
||||
def __init_subclass__(cls, io_type: str, **kwargs):
|
||||
@@ -1139,11 +1236,18 @@ def lock_class(cls):
|
||||
|
||||
def add_to_dict_v1(i: InputV3, input: dict):
|
||||
key = "optional" if i.optional else "required"
|
||||
input.setdefault(key, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
|
||||
as_dict = i.as_dict()
|
||||
# for v1, we don't want to include the optional key
|
||||
as_dict.pop("optional", None)
|
||||
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: InputV3 | OutputV3, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
|
||||
class ComfyNodeV3(ComfyNodeInternal):
|
||||
"""Common base class for all V3 nodes."""
|
||||
|
||||
class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
||||
"""Common base class for storing internal methods and properties; DO NOT USE for defining nodes."""
|
||||
|
||||
RELATIVE_PYTHON_MODULE = None
|
||||
SCHEMA = None
|
||||
@@ -1155,13 +1259,14 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def DEFINE_SCHEMA(cls) -> SchemaV3:
|
||||
def define_schema(cls) -> SchemaV3:
|
||||
"""Override this function with one that returns a SchemaV3 instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def execute(cls, **kwargs) -> NodeOutput:
|
||||
"""Override this function with one that performs node's actions."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -1190,28 +1295,28 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
"""
|
||||
return [name for name in kwargs if kwargs[name] is None]
|
||||
|
||||
@classmethod
|
||||
def GET_SERIALIZERS(cls) -> list[Serializer]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||
# schema = cls.GET_SCHEMA()
|
||||
# TODO: finish
|
||||
return None
|
||||
|
||||
def __init__(self):
|
||||
self.local_state: NodeStateLocal = None
|
||||
self.local_resources: ResourcesLocal = None
|
||||
self.__class__.VALIDATE_CLASS()
|
||||
|
||||
@classmethod
|
||||
def GET_SERIALIZERS(cls) -> list[Serializer]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def GET_BASE_CLASS(cls):
|
||||
return _ComfyNodeBaseInternal
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def VALIDATE_CLASS(cls):
|
||||
if not callable(cls.DEFINE_SCHEMA):
|
||||
raise Exception(f"No DEFINE_SCHEMA function was defined for node class {cls.__name__}.")
|
||||
if not callable(cls.execute):
|
||||
if first_real_override(cls, "define_schema") is None:
|
||||
raise Exception(f"No define_schema function was defined for node class {cls.__name__}.")
|
||||
if first_real_override(cls, "execute") is None:
|
||||
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput:
|
||||
to_return = cls.execute(*args, **kwargs)
|
||||
@@ -1228,6 +1333,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
else:
|
||||
raise Exception(f"Invalid return type from node: {type(to_return)}")
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNodeV3]:
|
||||
"""Creates clone of real node class to prevent monkey-patching."""
|
||||
@@ -1237,10 +1343,24 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
info = schema.get_v3_info(cls)
|
||||
return asdict(info)
|
||||
#############################################
|
||||
# V1 Backwards Compatibility code
|
||||
#--------------------------------------------
|
||||
@final
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
info = schema.get_v1_info(cls)
|
||||
return asdict(info)
|
||||
|
||||
_DESCRIPTION = None
|
||||
@final
|
||||
@classproperty
|
||||
def DESCRIPTION(cls): # noqa
|
||||
if cls._DESCRIPTION is None:
|
||||
@@ -1248,6 +1368,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._DESCRIPTION
|
||||
|
||||
_CATEGORY = None
|
||||
@final
|
||||
@classproperty
|
||||
def CATEGORY(cls): # noqa
|
||||
if cls._CATEGORY is None:
|
||||
@@ -1255,6 +1376,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._CATEGORY
|
||||
|
||||
_EXPERIMENTAL = None
|
||||
@final
|
||||
@classproperty
|
||||
def EXPERIMENTAL(cls): # noqa
|
||||
if cls._EXPERIMENTAL is None:
|
||||
@@ -1262,6 +1384,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._EXPERIMENTAL
|
||||
|
||||
_DEPRECATED = None
|
||||
@final
|
||||
@classproperty
|
||||
def DEPRECATED(cls): # noqa
|
||||
if cls._DEPRECATED is None:
|
||||
@@ -1269,6 +1392,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._DEPRECATED
|
||||
|
||||
_API_NODE = None
|
||||
@final
|
||||
@classproperty
|
||||
def API_NODE(cls): # noqa
|
||||
if cls._API_NODE is None:
|
||||
@@ -1276,6 +1400,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._API_NODE
|
||||
|
||||
_OUTPUT_NODE = None
|
||||
@final
|
||||
@classproperty
|
||||
def OUTPUT_NODE(cls): # noqa
|
||||
if cls._OUTPUT_NODE is None:
|
||||
@@ -1283,6 +1408,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._OUTPUT_NODE
|
||||
|
||||
_INPUT_IS_LIST = None
|
||||
@final
|
||||
@classproperty
|
||||
def INPUT_IS_LIST(cls): # noqa
|
||||
if cls._INPUT_IS_LIST is None:
|
||||
@@ -1290,6 +1416,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._INPUT_IS_LIST
|
||||
_OUTPUT_IS_LIST = None
|
||||
|
||||
@final
|
||||
@classproperty
|
||||
def OUTPUT_IS_LIST(cls): # noqa
|
||||
if cls._OUTPUT_IS_LIST is None:
|
||||
@@ -1297,6 +1424,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._OUTPUT_IS_LIST
|
||||
|
||||
_RETURN_TYPES = None
|
||||
@final
|
||||
@classproperty
|
||||
def RETURN_TYPES(cls): # noqa
|
||||
if cls._RETURN_TYPES is None:
|
||||
@@ -1304,6 +1432,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._RETURN_TYPES
|
||||
|
||||
_RETURN_NAMES = None
|
||||
@final
|
||||
@classproperty
|
||||
def RETURN_NAMES(cls): # noqa
|
||||
if cls._RETURN_NAMES is None:
|
||||
@@ -1311,6 +1440,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._RETURN_NAMES
|
||||
|
||||
_OUTPUT_TOOLTIPS = None
|
||||
@final
|
||||
@classproperty
|
||||
def OUTPUT_TOOLTIPS(cls): # noqa
|
||||
if cls._OUTPUT_TOOLTIPS is None:
|
||||
@@ -1318,6 +1448,7 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
return cls._OUTPUT_TOOLTIPS
|
||||
|
||||
_NOT_IDEMPOTENT = None
|
||||
@final
|
||||
@classproperty
|
||||
def NOT_IDEMPOTENT(cls): # noqa
|
||||
if cls._NOT_IDEMPOTENT is None:
|
||||
@@ -1326,35 +1457,27 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
|
||||
FUNCTION = "EXECUTE_NORMALIZED"
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]:
|
||||
schema = cls.FINALIZE_SCHEMA()
|
||||
# for V1, make inputs be a dict with potential keys {required, optional, hidden}
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
if schema.inputs:
|
||||
for i in schema.inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
dynamic_inputs = i.get_dynamic()
|
||||
for d in dynamic_inputs:
|
||||
add_to_dict_v1(d, input)
|
||||
else:
|
||||
add_to_dict_v1(i, input)
|
||||
if schema.hidden and include_hidden:
|
||||
for hidden in schema.hidden:
|
||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||
info = schema.get_v1_info(cls)
|
||||
input = info.input
|
||||
if not include_hidden:
|
||||
input.pop("hidden", None)
|
||||
if return_schema:
|
||||
return input, schema
|
||||
return input
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def FINALIZE_SCHEMA(cls):
|
||||
"""Call DEFINE_SCHEMA and finalize it."""
|
||||
schema = cls.DEFINE_SCHEMA()
|
||||
"""Call define_schema and finalize it."""
|
||||
schema = cls.define_schema()
|
||||
schema.finalize()
|
||||
return schema
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def GET_SCHEMA(cls) -> SchemaV3:
|
||||
"""Validate node class, finalize schema, validate schema, and set expected class properties."""
|
||||
@@ -1396,47 +1519,58 @@ class ComfyNodeV3(ComfyNodeInternal):
|
||||
cls._OUTPUT_TOOLTIPS = output_tooltips
|
||||
cls.SCHEMA = schema
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
# get V1 inputs
|
||||
input = cls.INPUT_TYPES()
|
||||
|
||||
# create separate lists from output fields
|
||||
output = []
|
||||
output_is_list = []
|
||||
output_name = []
|
||||
output_tooltips = []
|
||||
if schema.outputs:
|
||||
for o in schema.outputs:
|
||||
output.append(o.io_type)
|
||||
output_is_list.append(o.is_output_list)
|
||||
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
input_order={key: list(value.keys()) for (key, value) in input.items()},
|
||||
output=output,
|
||||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
output_tooltips=output_tooltips,
|
||||
name=schema.node_id,
|
||||
display_name=schema.display_name,
|
||||
category=schema.category,
|
||||
description=schema.description,
|
||||
output_node=schema.is_output_node,
|
||||
deprecated=schema.is_deprecated,
|
||||
experimental=schema.is_experimental,
|
||||
api_node=schema.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
)
|
||||
return asdict(info)
|
||||
#--------------------------------------------
|
||||
#############################################
|
||||
|
||||
|
||||
class ComfyNodeV3(_ComfyNodeBaseInternal):
|
||||
"""Common base class for all V3 nodes."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def define_schema(cls) -> SchemaV3:
|
||||
"""Override this function with one that returns a SchemaV3 instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def execute(cls, **kwargs) -> NodeOutput:
|
||||
"""Override this function with one that performs node's actions."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, **kwargs) -> bool:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, **kwargs) -> Any:
|
||||
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def check_lazy_status(cls, **kwargs) -> list[str]:
|
||||
"""Optionally, define this function to return a list of input names that should be evaluated.
|
||||
|
||||
This basic mixin impl. requires all inputs.
|
||||
|
||||
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
|
||||
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
|
||||
|
||||
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
|
||||
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
return [name for name in kwargs if kwargs[name] is None]
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def GET_BASE_CLASS(cls):
|
||||
"""DO NOT override this class. Will break things in execution.py."""
|
||||
return ComfyNodeV3
|
||||
|
||||
|
||||
class NodeOutput:
|
||||
'''
|
||||
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
|
||||
@@ -1478,57 +1612,4 @@ class _UIOutput(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def as_dict(self) -> dict:
|
||||
... # TODO: finish
|
||||
|
||||
class TestNode(ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
return SchemaV3(
|
||||
node_id="TestNode_v3",
|
||||
display_name="Test Node (V3)",
|
||||
category="v3_test",
|
||||
inputs=[Int.Input("my_int"),
|
||||
#AutoGrowDynamicInput("growing", Image.Input),
|
||||
Mask.Input("thing"),
|
||||
],
|
||||
outputs=[Image.Output("image_output")],
|
||||
hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, **kwargs):
|
||||
pass
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# print("hello there")
|
||||
# inputs: list[InputV3] = [
|
||||
# Int.Input("tessfes", widgetType=String.io_type),
|
||||
# Int.Input("my_int"),
|
||||
# Custom("XYZ").Input("xyz"),
|
||||
# Custom("MODEL_M").Input("model1"),
|
||||
# Image.Input("my_image"),
|
||||
# Float.Input("my_float"),
|
||||
# MultiType.Input("my_inputs", [String, Custom("MODEL_M"), Custom("XYZ")]),
|
||||
# ]
|
||||
# Custom("XYZ").Input()
|
||||
# outputs: list[OutputV3] = [
|
||||
# Image.Output("image"),
|
||||
# Custom("XYZ").Output("xyz"),
|
||||
# ]
|
||||
#
|
||||
# for c in inputs:
|
||||
# if isinstance(c, MultiType):
|
||||
# print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
|
||||
# print(c.get_io_type_V1())
|
||||
# else:
|
||||
# print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
|
||||
#
|
||||
# for c in outputs:
|
||||
# print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
|
||||
#
|
||||
# zz = TestNode()
|
||||
# print(zz.GET_NODE_INFO_V1())
|
||||
#
|
||||
# # aa = NodeInfoV1()
|
||||
# # print(asdict(aa))
|
||||
# # print(as_pruned_dict(aa))
|
||||
...
|
||||
|
@@ -27,7 +27,7 @@ class V3TestNode(io.ComfyNodeV3):
|
||||
self.hahajkunless = ";)"
|
||||
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="V3_01_TestNode1",
|
||||
display_name="V3 Test Node",
|
||||
@@ -113,7 +113,7 @@ class V3TestNode(io.ComfyNodeV3):
|
||||
|
||||
class V3LoraLoader(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="V3_LoraLoader",
|
||||
display_name="V3 LoRA Loader",
|
||||
@@ -163,7 +163,7 @@ class V3LoraLoader(io.ComfyNodeV3):
|
||||
|
||||
class NInputsTest(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="V3_NInputsTest",
|
||||
display_name="V3 N Inputs Test",
|
||||
|
57
comfy_extras/v3/nodes_ace.py
Normal file
57
comfy_extras/v3/nodes_ace.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class TextEncodeAceStepAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="TextEncodeAceStepAudio_V3",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||
conditioning = clip.encode_from_tokens_scheduled(clip.tokenize(tags, lyrics=lyrics))
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="EmptyAceStepLatentAudio_V3",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
]
|
128
comfy_extras/v3/nodes_advanced_samplers.py
Normal file
128
comfy_extras/v3/nodes_advanced_samplers.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import trange
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lcm_upscale(
|
||||
model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None
|
||||
):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
if upscale_steps is None:
|
||||
upscale_steps = max(len(sigmas) // 2 + 1, 2)
|
||||
else:
|
||||
upscale_steps += 1
|
||||
upscale_steps = min(upscale_steps, len(sigmas) + 1)
|
||||
|
||||
upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]
|
||||
|
||||
orig_shape = x.size()
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
x = denoised
|
||||
if i < len(upscales):
|
||||
x = comfy.utils.common_upscale(
|
||||
x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled"
|
||||
)
|
||||
|
||||
if sigmas[i + 1] > 0:
|
||||
x += sigmas[i + 1] * torch.randn_like(x)
|
||||
return x
|
||||
|
||||
|
||||
class SamplerLCMUpscale(io.ComfyNodeV3):
|
||||
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="SamplerLCMUpscale_V3",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
|
||||
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
|
||||
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
||||
if scale_steps < 0:
|
||||
scale_steps = None
|
||||
sampler = comfy.samplers.KSAMPLER(
|
||||
sample_lcm_upscale,
|
||||
extra_options={
|
||||
"total_upscale": scale_ratio,
|
||||
"upscale_steps": scale_steps,
|
||||
"upscale_method": upscale_method,
|
||||
},
|
||||
)
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(
|
||||
model_options, post_cfg_function, disable_cfg1_optimization=True
|
||||
)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x - denoised + temp[0], sigmas[i], denoised)
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
class SamplerEulerCFGpp(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="SamplerEulerCFGpp_V3",
|
||||
display_name="SamplerEulerCFG++ _V3",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Combo.Input("version", options=["regular", "alternative"]),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, version) -> io.NodeOutput:
|
||||
if version == "alternative":
|
||||
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
||||
else:
|
||||
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
SamplerLCMUpscale,
|
||||
SamplerEulerCFGpp,
|
||||
]
|
83
comfy_extras/v3/nodes_align_your_steps.py
Normal file
83
comfy_extras/v3/nodes_align_your_steps.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
NOISE_LEVELS = {
|
||||
"SD1": [
|
||||
14.6146412293,
|
||||
6.4745760956,
|
||||
3.8636745985,
|
||||
2.6946151520,
|
||||
1.8841921177,
|
||||
1.3943805092,
|
||||
0.9642583904,
|
||||
0.6523686016,
|
||||
0.3977456272,
|
||||
0.1515232662,
|
||||
0.0291671582,
|
||||
],
|
||||
"SDXL": [
|
||||
14.6146412293,
|
||||
6.3184485287,
|
||||
3.7681790315,
|
||||
2.1811480769,
|
||||
1.3405244945,
|
||||
0.8620721141,
|
||||
0.5550693289,
|
||||
0.3798540708,
|
||||
0.2332364134,
|
||||
0.1114188177,
|
||||
0.0291671582,
|
||||
],
|
||||
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002],
|
||||
}
|
||||
|
||||
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""Performs log-linear interpolation of a given array of decreasing numbers."""
|
||||
xs = np.linspace(0, 1, len(t_steps))
|
||||
ys = np.log(t_steps[::-1])
|
||||
|
||||
new_xs = np.linspace(0, 1, num_steps)
|
||||
new_ys = np.interp(new_xs, xs, ys)
|
||||
|
||||
return np.exp(new_ys)[::-1].copy()
|
||||
|
||||
|
||||
class AlignYourStepsScheduler(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="AlignYourStepsScheduler_V3",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
io.Int.Input("steps", default=10, min=1, max=10000),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Sigmas.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
if denoise <= 0.0:
|
||||
return io.NodeOutput(torch.FloatTensor([]))
|
||||
total_steps = round(steps * denoise)
|
||||
|
||||
sigmas = NOISE_LEVELS[model_type][:]
|
||||
if (steps + 1) != len(sigmas):
|
||||
sigmas = loglinear_interp(sigmas, steps + 1)
|
||||
|
||||
sigmas = sigmas[-(total_steps + 1) :]
|
||||
sigmas[-1] = 0
|
||||
return io.NodeOutput(torch.FloatTensor(sigmas))
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
AlignYourStepsScheduler,
|
||||
]
|
98
comfy_extras/v3/nodes_apg.py
Normal file
98
comfy_extras/v3/nodes_apg.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def project(v0, v1):
|
||||
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
||||
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
|
||||
class APG(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="APG_V3",
|
||||
display_name="Adaptive Projected Guidance _V3",
|
||||
category="sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
"eta",
|
||||
default=1.0,
|
||||
min=-10.0,
|
||||
max=10.0,
|
||||
step=0.01,
|
||||
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"norm_threshold",
|
||||
default=5.0,
|
||||
min=0.0,
|
||||
max=50.0,
|
||||
step=0.1,
|
||||
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"momentum",
|
||||
default=0.0,
|
||||
min=-5.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
||||
running_avg = 0
|
||||
prev_sigma = None
|
||||
|
||||
def pre_cfg_function(args):
|
||||
nonlocal running_avg, prev_sigma
|
||||
|
||||
if len(args["conds_out"]) == 1:
|
||||
return args["conds_out"]
|
||||
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
sigma = args["sigma"][0]
|
||||
cond_scale = args["cond_scale"]
|
||||
|
||||
if prev_sigma is not None and sigma > prev_sigma:
|
||||
running_avg = 0
|
||||
prev_sigma = sigma
|
||||
|
||||
guidance = cond - uncond
|
||||
|
||||
if momentum != 0:
|
||||
if not torch.is_tensor(running_avg):
|
||||
running_avg = guidance
|
||||
else:
|
||||
running_avg = momentum * running_avg + guidance
|
||||
guidance = running_avg
|
||||
|
||||
if norm_threshold > 0:
|
||||
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
|
||||
scale = torch.minimum(torch.ones_like(guidance_norm), norm_threshold / guidance_norm)
|
||||
guidance = guidance * scale
|
||||
|
||||
guidance_parallel, guidance_orthogonal = project(guidance, cond)
|
||||
modified_guidance = guidance_orthogonal + eta * guidance_parallel
|
||||
|
||||
modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
|
||||
|
||||
return [modified_cond, uncond] + args["conds_out"][2:]
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
APG,
|
||||
]
|
139
comfy_extras/v3/nodes_attention_multiply.py
Normal file
139
comfy_extras/v3/nodes_attention_multiply.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def attention_multiply(attn, model, q, k, v, out):
|
||||
m = model.clone()
|
||||
sd = model.model_state_dict()
|
||||
|
||||
for key in sd:
|
||||
if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
|
||||
m.add_patches({key: (None,)}, 0.0, q)
|
||||
if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
|
||||
m.add_patches({key: (None,)}, 0.0, k)
|
||||
if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)):
|
||||
m.add_patches({key: (None,)}, 0.0, v)
|
||||
if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)):
|
||||
m.add_patches({key: (None,)}, 0.0, out)
|
||||
return m
|
||||
|
||||
|
||||
class UNetSelfAttentionMultiply(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="UNetSelfAttentionMultiply_V3",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
return io.NodeOutput(attention_multiply("attn1", model, q, k, v, out))
|
||||
|
||||
|
||||
class UNetCrossAttentionMultiply(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="UNetCrossAttentionMultiply_V3",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
return io.NodeOutput(attention_multiply("attn2", model, q, k, v, out))
|
||||
|
||||
|
||||
class CLIPAttentionMultiply(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="CLIPAttentionMultiply_V3",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
||||
m = clip.clone()
|
||||
sd = m.patcher.model_state_dict()
|
||||
|
||||
for key in sd:
|
||||
if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, q)
|
||||
if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, k)
|
||||
if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, v)
|
||||
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, out)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class UNetTemporalAttentionMultiply(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.SchemaV3:
|
||||
return io.SchemaV3(
|
||||
node_id="UNetTemporalAttentionMultiply_V3",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
sd = model.model_state_dict()
|
||||
|
||||
for k in sd:
|
||||
if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")):
|
||||
if '.time_stack.' in k:
|
||||
m.add_patches({k: (None,)}, 0.0, self_temporal)
|
||||
else:
|
||||
m.add_patches({k: (None,)}, 0.0, self_structural)
|
||||
elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")):
|
||||
if '.time_stack.' in k:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||
else:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
UNetSelfAttentionMultiply,
|
||||
UNetCrossAttentionMultiply,
|
||||
CLIPAttentionMultiply,
|
||||
UNetTemporalAttentionMultiply,
|
||||
]
|
@@ -16,9 +16,9 @@ from comfy.cli_args import args
|
||||
from comfy_api.v3 import io, ui
|
||||
|
||||
|
||||
class ConditioningStableAudio_V3(io.ComfyNodeV3):
|
||||
class ConditioningStableAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ConditioningStableAudio_V3",
|
||||
category="conditioning",
|
||||
@@ -46,9 +46,9 @@ class ConditioningStableAudio_V3(io.ComfyNodeV3):
|
||||
)
|
||||
|
||||
|
||||
class EmptyLatentAudio_V3(io.ComfyNodeV3):
|
||||
class EmptyLatentAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="EmptyLatentAudio_V3",
|
||||
category="latent/audio",
|
||||
@@ -68,9 +68,9 @@ class EmptyLatentAudio_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
class LoadAudio_V3(io.ComfyNodeV3):
|
||||
class LoadAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LoadAudio_V3", # frontend expects "LoadAudio" to work
|
||||
display_name="Load Audio _V3", # frontend ignores "display_name" for this node
|
||||
@@ -106,9 +106,9 @@ class LoadAudio_V3(io.ComfyNodeV3):
|
||||
return True
|
||||
|
||||
|
||||
class PreviewAudio_V3(io.ComfyNodeV3):
|
||||
class PreviewAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PreviewAudio_V3", # frontend expects "PreviewAudio" to work
|
||||
display_name="Preview Audio _V3", # frontend ignores "display_name" for this node
|
||||
@@ -125,9 +125,9 @@ class PreviewAudio_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(ui=ui.PreviewAudio(audio, cls=cls))
|
||||
|
||||
|
||||
class SaveAudioMP3_V3(io.ComfyNodeV3):
|
||||
class SaveAudioMP3(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SaveAudioMP3_V3", # frontend expects "SaveAudioMP3" to work
|
||||
display_name="Save Audio(MP3) _V3", # frontend ignores "display_name" for this node
|
||||
@@ -146,9 +146,9 @@ class SaveAudioMP3_V3(io.ComfyNodeV3):
|
||||
return _save_audio(self, audio, filename_prefix, format, quality)
|
||||
|
||||
|
||||
class SaveAudioOpus_V3(io.ComfyNodeV3):
|
||||
class SaveAudioOpus(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SaveAudioOpus_V3", # frontend expects "SaveAudioOpus" to work
|
||||
display_name="Save Audio(Opus) _V3", # frontend ignores "display_name" for this node
|
||||
@@ -167,9 +167,9 @@ class SaveAudioOpus_V3(io.ComfyNodeV3):
|
||||
return _save_audio(self, audio, filename_prefix, format, quality)
|
||||
|
||||
|
||||
class SaveAudio_V3(io.ComfyNodeV3):
|
||||
class SaveAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SaveAudio_V3", # frontend expects "SaveAudio" to work
|
||||
display_name="Save Audio _V3", # frontend ignores "display_name" for this node
|
||||
@@ -187,9 +187,9 @@ class SaveAudio_V3(io.ComfyNodeV3):
|
||||
return _save_audio(cls, audio, filename_prefix, format)
|
||||
|
||||
|
||||
class VAEDecodeAudio_V3(io.ComfyNodeV3):
|
||||
class VAEDecodeAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="VAEDecodeAudio_V3",
|
||||
category="latent/audio",
|
||||
@@ -209,9 +209,9 @@ class VAEDecodeAudio_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
||||
|
||||
|
||||
class VAEEncodeAudio_V3(io.ComfyNodeV3):
|
||||
class VAEEncodeAudio(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="VAEEncodeAudio_V3",
|
||||
category="latent/audio",
|
||||
@@ -335,13 +335,13 @@ def _save_audio(cls, audio, filename_prefix="ComfyUI", format="flac", quality="1
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
ConditioningStableAudio_V3,
|
||||
EmptyLatentAudio_V3,
|
||||
LoadAudio_V3,
|
||||
PreviewAudio_V3,
|
||||
SaveAudioMP3_V3,
|
||||
SaveAudioOpus_V3,
|
||||
SaveAudio_V3,
|
||||
VAEDecodeAudio_V3,
|
||||
VAEEncodeAudio_V3,
|
||||
ConditioningStableAudio,
|
||||
EmptyLatentAudio,
|
||||
LoadAudio,
|
||||
PreviewAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
SaveAudio,
|
||||
VAEDecodeAudio,
|
||||
VAEEncodeAudio,
|
||||
]
|
||||
|
@@ -3,9 +3,9 @@ from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class ControlNetApplyAdvanced_V3(io.ComfyNodeV3):
|
||||
class ControlNetApplyAdvanced(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ControlNetApplyAdvanced_V3",
|
||||
display_name="Apply ControlNet _V3",
|
||||
@@ -60,9 +60,9 @@ class ControlNetApplyAdvanced_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(out[0], out[1])
|
||||
|
||||
|
||||
class SetUnionControlNetType_V3(io.ComfyNodeV3):
|
||||
class SetUnionControlNetType(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SetUnionControlNetType_V3",
|
||||
category="conditioning/controlnet",
|
||||
@@ -87,9 +87,9 @@ class SetUnionControlNetType_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(control_net)
|
||||
|
||||
|
||||
class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3):
|
||||
class ControlNetInpaintingAliMamaApply(ControlNetApplyAdvanced):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ControlNetInpaintingAliMamaApply_V3",
|
||||
category="conditioning/controlnet",
|
||||
@@ -135,7 +135,7 @@ class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3):
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
ControlNetApplyAdvanced_V3,
|
||||
SetUnionControlNetType_V3,
|
||||
ControlNetInpaintingAliMamaApply_V3,
|
||||
ControlNetApplyAdvanced,
|
||||
SetUnionControlNetType,
|
||||
ControlNetInpaintingAliMamaApply,
|
||||
]
|
||||
|
@@ -18,7 +18,7 @@ from server import PromptServer
|
||||
|
||||
class GetImageSize(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="GetImageSize_V3",
|
||||
display_name="Get Image Size _V3",
|
||||
@@ -51,7 +51,7 @@ class GetImageSize(io.ComfyNodeV3):
|
||||
|
||||
class ImageAddNoise(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageAddNoise_V3",
|
||||
display_name="Image Add Noise _V3",
|
||||
@@ -84,7 +84,7 @@ class ImageAddNoise(io.ComfyNodeV3):
|
||||
|
||||
class ImageCrop(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageCrop_V3",
|
||||
display_name="Image Crop _V3",
|
||||
@@ -110,7 +110,7 @@ class ImageCrop(io.ComfyNodeV3):
|
||||
|
||||
class ImageFlip(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageFlip_V3",
|
||||
display_name="Image Flip _V3",
|
||||
@@ -134,7 +134,7 @@ class ImageFlip(io.ComfyNodeV3):
|
||||
|
||||
class ImageFromBatch(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageFromBatch_V3",
|
||||
display_name="Image From Batch _V3",
|
||||
@@ -158,7 +158,7 @@ class ImageFromBatch(io.ComfyNodeV3):
|
||||
|
||||
class ImageRotate(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageRotate_V3",
|
||||
display_name="Image Rotate _V3",
|
||||
@@ -187,7 +187,7 @@ class ImageStitch(io.ComfyNodeV3):
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageStitch_V3",
|
||||
display_name="Image Stitch _V3",
|
||||
@@ -355,7 +355,7 @@ class ImageStitch(io.ComfyNodeV3):
|
||||
|
||||
class LoadImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LoadImage_V3",
|
||||
display_name="Load Image _V3",
|
||||
@@ -443,7 +443,7 @@ class LoadImage(io.ComfyNodeV3):
|
||||
|
||||
class LoadImageOutput(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LoadImageOutput_V3",
|
||||
display_name="Load Image (from Outputs) _V3",
|
||||
@@ -532,7 +532,7 @@ class LoadImageOutput(io.ComfyNodeV3):
|
||||
|
||||
class PreviewImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PreviewImage_V3",
|
||||
display_name="Preview Image _V3",
|
||||
@@ -552,7 +552,7 @@ class PreviewImage(io.ComfyNodeV3):
|
||||
|
||||
class RepeatImageBatch(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="RepeatImageBatch_V3",
|
||||
display_name="Repeat Image Batch _V3",
|
||||
@@ -571,7 +571,7 @@ class RepeatImageBatch(io.ComfyNodeV3):
|
||||
|
||||
class ResizeAndPadImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ResizeAndPadImage_V3",
|
||||
display_name="Resize and Pad Image _V3",
|
||||
@@ -616,7 +616,7 @@ class ResizeAndPadImage(io.ComfyNodeV3):
|
||||
|
||||
class SaveAnimatedPNG(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SaveAnimatedPNG_V3",
|
||||
display_name="Save Animated PNG _V3",
|
||||
@@ -681,7 +681,7 @@ class SaveAnimatedWEBP(io.ComfyNodeV3):
|
||||
COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SaveAnimatedWEBP_V3",
|
||||
display_name="Save Animated WEBP _V3",
|
||||
@@ -744,7 +744,7 @@ class SaveAnimatedWEBP(io.ComfyNodeV3):
|
||||
|
||||
class SaveImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SaveImage_V3",
|
||||
display_name="Save Image _V3",
|
||||
|
@@ -1,7 +1,341 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import nodes
|
||||
from comfy_api.v3 import io, ui
|
||||
|
||||
|
||||
class MaskPreview_V3(io.ComfyNodeV3):
|
||||
def composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False):
|
||||
source = source.to(destination.device)
|
||||
if resize_source:
|
||||
source = torch.nn.functional.interpolate(
|
||||
source, size=(destination.shape[2], destination.shape[3]), mode="bilinear"
|
||||
)
|
||||
|
||||
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
|
||||
|
||||
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
|
||||
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
|
||||
|
||||
left, top = (x // multiplier, y // multiplier)
|
||||
right, bottom = (
|
||||
left + source.shape[3],
|
||||
top + source.shape[2],
|
||||
)
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(source)
|
||||
else:
|
||||
mask = mask.to(destination.device, copy=True)
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
|
||||
size=(source.shape[2], source.shape[3]),
|
||||
mode="bilinear",
|
||||
)
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
||||
|
||||
# calculate the bounds of the source that will be overlapping the destination
|
||||
# this prevents the source trying to overwrite latent pixels that are out of bounds
|
||||
# of the destination
|
||||
visible_width, visible_height = (
|
||||
destination.shape[3] - left + min(0, x),
|
||||
destination.shape[2] - top + min(0, y),
|
||||
)
|
||||
|
||||
mask = mask[:, :, :visible_height, :visible_width]
|
||||
inverse_mask = torch.ones_like(mask) - mask
|
||||
|
||||
source_portion = mask * source[:, :, :visible_height, :visible_width]
|
||||
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
|
||||
|
||||
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
|
||||
return destination
|
||||
|
||||
|
||||
class CropMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="CropMask_V3",
|
||||
display_name="Crop Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
io.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, x, y, width, height) -> io.NodeOutput:
|
||||
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||
return io.NodeOutput(mask[:, y : y + height, x : x + width])
|
||||
|
||||
|
||||
class FeatherMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="FeatherMask_V3",
|
||||
display_name="Feather Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
io.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, left, top, right, bottom) -> io.NodeOutput:
|
||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||
|
||||
left = min(left, output.shape[-1])
|
||||
right = min(right, output.shape[-1])
|
||||
top = min(top, output.shape[-2])
|
||||
bottom = min(bottom, output.shape[-2])
|
||||
|
||||
for x in range(left):
|
||||
feather_rate = (x + 1.0) / left
|
||||
output[:, :, x] *= feather_rate
|
||||
|
||||
for x in range(right):
|
||||
feather_rate = (x + 1) / right
|
||||
output[:, :, -x] *= feather_rate
|
||||
|
||||
for y in range(top):
|
||||
feather_rate = (y + 1) / top
|
||||
output[:, y, :] *= feather_rate
|
||||
|
||||
for y in range(bottom):
|
||||
feather_rate = (y + 1) / bottom
|
||||
output[:, -y, :] *= feather_rate
|
||||
|
||||
return io.NodeOutput(output)
|
||||
|
||||
|
||||
class GrowMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="GrowMask_V3",
|
||||
display_name="Grow Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
io.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION),
|
||||
io.Boolean.Input("tapered_corners", default=True),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, expand, tapered_corners) -> io.NodeOutput:
|
||||
c = 0 if tapered_corners else 1
|
||||
kernel = np.array([[c, 1, c], [1, 1, 1], [c, 1, c]])
|
||||
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||
out = []
|
||||
for m in mask:
|
||||
output = m.numpy()
|
||||
for _ in range(abs(expand)):
|
||||
if expand < 0:
|
||||
output = scipy.ndimage.grey_erosion(output, footprint=kernel)
|
||||
else:
|
||||
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
||||
output = torch.from_numpy(output)
|
||||
out.append(output)
|
||||
return io.NodeOutput(torch.stack(out, dim=0))
|
||||
|
||||
|
||||
class ImageColorToMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageColorToMask_V3",
|
||||
display_name="Image Color to Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("color", default=0, min=0, max=0xFFFFFF, display_mode=io.NumberDisplay.color),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, color) -> io.NodeOutput:
|
||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||
temp = (
|
||||
torch.bitwise_left_shift(temp[:, :, :, 0], 16)
|
||||
+ torch.bitwise_left_shift(temp[:, :, :, 1], 8)
|
||||
+ temp[:, :, :, 2]
|
||||
)
|
||||
return io.NodeOutput(torch.where(temp == color, 1.0, 0).float())
|
||||
|
||||
|
||||
class ImageCompositeMasked(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageCompositeMasked_V3",
|
||||
display_name="Image Composite Masked _V3",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input("destination"),
|
||||
io.Image.Input("source"),
|
||||
io.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Boolean.Input("resize_source", default=False),
|
||||
io.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask=None) -> io.NodeOutput:
|
||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||
destination = destination.clone().movedim(-1, 1)
|
||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||
return io.NodeOutput(output)
|
||||
|
||||
|
||||
class ImageToMask(io.ComfyNodeV3):
|
||||
CHANNELS = ["red", "green", "blue", "alpha"]
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ImageToMask_V3",
|
||||
display_name="Convert Image to Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("channel", options=cls.CHANNELS),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, channel) -> io.NodeOutput:
|
||||
return io.NodeOutput(image[:, :, :, cls.CHANNELS.index(channel)])
|
||||
|
||||
|
||||
class InvertMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="InvertMask_V3",
|
||||
display_name="Invert Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask) -> io.NodeOutput:
|
||||
return io.NodeOutput(1.0 - mask)
|
||||
|
||||
|
||||
class LatentCompositeMasked(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentCompositeMasked_V3",
|
||||
display_name="Latent Composite Masked _V3",
|
||||
category="latent",
|
||||
inputs=[
|
||||
io.Latent.Input("destination"),
|
||||
io.Latent.Input("source"),
|
||||
io.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Boolean.Input("resize_source", default=False),
|
||||
io.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask=None) -> io.NodeOutput:
|
||||
output = destination.copy()
|
||||
destination_samples = destination["samples"].clone()
|
||||
source_samples = source["samples"]
|
||||
output["samples"] = composite(destination_samples, source_samples, x, y, mask, 8, resize_source)
|
||||
return io.NodeOutput(output)
|
||||
|
||||
|
||||
class MaskComposite(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="MaskComposite_V3",
|
||||
display_name="Mask Composite _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("destination"),
|
||||
io.Mask.Input("source"),
|
||||
io.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION),
|
||||
io.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, operation) -> io.NodeOutput:
|
||||
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
||||
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
||||
|
||||
left, top = (
|
||||
x,
|
||||
y,
|
||||
)
|
||||
right, bottom = (
|
||||
min(left + source.shape[-1], destination.shape[-1]),
|
||||
min(top + source.shape[-2], destination.shape[-2]),
|
||||
)
|
||||
visible_width, visible_height = (
|
||||
right - left,
|
||||
bottom - top,
|
||||
)
|
||||
|
||||
source_portion = source[:, :visible_height, :visible_width]
|
||||
destination_portion = output[:, top:bottom, left:right]
|
||||
|
||||
if operation == "multiply":
|
||||
output[:, top:bottom, left:right] = destination_portion * source_portion
|
||||
elif operation == "add":
|
||||
output[:, top:bottom, left:right] = destination_portion + source_portion
|
||||
elif operation == "subtract":
|
||||
output[:, top:bottom, left:right] = destination_portion - source_portion
|
||||
elif operation == "and":
|
||||
output[:, top:bottom, left:right] = torch.bitwise_and(
|
||||
destination_portion.round().bool(), source_portion.round().bool()
|
||||
).float()
|
||||
elif operation == "or":
|
||||
output[:, top:bottom, left:right] = torch.bitwise_or(
|
||||
destination_portion.round().bool(), source_portion.round().bool()
|
||||
).float()
|
||||
elif operation == "xor":
|
||||
output[:, top:bottom, left:right] = torch.bitwise_xor(
|
||||
destination_portion.round().bool(), source_portion.round().bool()
|
||||
).float()
|
||||
|
||||
return io.NodeOutput(torch.clamp(output, 0.0, 1.0))
|
||||
|
||||
|
||||
class MaskPreview(io.ComfyNodeV3):
|
||||
"""Mask Preview - original implement in ComfyUI_essentials.
|
||||
|
||||
https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||
@@ -9,7 +343,7 @@ class MaskPreview_V3(io.ComfyNodeV3):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="MaskPreview_V3",
|
||||
display_name="Preview Mask _V3",
|
||||
@@ -26,4 +360,75 @@ class MaskPreview_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(ui=ui.PreviewMask(masks))
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [MaskPreview_V3]
|
||||
class MaskToImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="MaskToImage_V3",
|
||||
display_name="Convert Mask to Image _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask) -> io.NodeOutput:
|
||||
return io.NodeOutput(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3))
|
||||
|
||||
|
||||
class SolidMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="SolidMask_V3",
|
||||
display_name="Solid Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
||||
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value, width, height) -> io.NodeOutput:
|
||||
return io.NodeOutput(torch.full((1, height, width), value, dtype=torch.float32, device="cpu"))
|
||||
|
||||
|
||||
class ThresholdMask(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ThresholdMask_V3",
|
||||
display_name="Threshold Mask _V3",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
io.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Mask.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, value) -> io.NodeOutput:
|
||||
return io.NodeOutput((mask > value).float())
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
CropMask,
|
||||
FeatherMask,
|
||||
GrowMask,
|
||||
ImageColorToMask,
|
||||
ImageCompositeMasked,
|
||||
ImageToMask,
|
||||
InvertMask,
|
||||
LatentCompositeMasked,
|
||||
MaskComposite,
|
||||
MaskPreview,
|
||||
MaskToImage,
|
||||
SolidMask,
|
||||
ThresholdMask,
|
||||
]
|
||||
|
47
comfy_extras/v3/nodes_preview_any.py
Normal file
47
comfy_extras/v3/nodes_preview_any.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from comfy_api.v3 import io, ui
|
||||
|
||||
|
||||
class PreviewAny(io.ComfyNodeV3):
|
||||
"""Originally implement from https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
|
||||
|
||||
upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PreviewAny_V3", # frontend expects "PreviewAny" to work
|
||||
display_name="Preview Any _V3", # frontend ignores "display_name" for this node
|
||||
description="Preview any type of data by converting it to a readable text format.",
|
||||
category="utils",
|
||||
inputs=[
|
||||
io.AnyType.Input("source"), # TODO: does not work currently, as `io.AnyType` does not define __ne__
|
||||
],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, source=None) -> io.NodeOutput:
|
||||
value = "None"
|
||||
if isinstance(source, str):
|
||||
value = source
|
||||
elif isinstance(source, (int, float, bool)):
|
||||
value = str(source)
|
||||
elif source is not None:
|
||||
try:
|
||||
value = json.dumps(source)
|
||||
except Exception:
|
||||
try:
|
||||
value = str(source)
|
||||
except Exception:
|
||||
value = "source exists, but could not be serialized."
|
||||
|
||||
return io.NodeOutput(ui=ui.PreviewText(value))
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
PreviewAny,
|
||||
]
|
@@ -5,9 +5,9 @@ import sys
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class String_V3(io.ComfyNodeV3):
|
||||
class String(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PrimitiveString_V3",
|
||||
display_name="String _V3",
|
||||
@@ -23,9 +23,9 @@ class String_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class StringMultiline_V3(io.ComfyNodeV3):
|
||||
class StringMultiline(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PrimitiveStringMultiline_V3",
|
||||
display_name="String (Multiline) _V3",
|
||||
@@ -41,9 +41,9 @@ class StringMultiline_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Int_V3(io.ComfyNodeV3):
|
||||
class Int(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PrimitiveInt_V3",
|
||||
display_name="Int _V3",
|
||||
@@ -59,9 +59,9 @@ class Int_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Float_V3(io.ComfyNodeV3):
|
||||
class Float(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PrimitiveFloat_V3",
|
||||
display_name="Float _V3",
|
||||
@@ -77,9 +77,9 @@ class Float_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(value)
|
||||
|
||||
|
||||
class Boolean_V3(io.ComfyNodeV3):
|
||||
class Boolean(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="PrimitiveBoolean_V3",
|
||||
display_name="Boolean _V3",
|
||||
@@ -96,9 +96,9 @@ class Boolean_V3(io.ComfyNodeV3):
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
String_V3,
|
||||
StringMultiline_V3,
|
||||
Int_V3,
|
||||
Float_V3,
|
||||
Boolean_V3,
|
||||
String,
|
||||
StringMultiline,
|
||||
Int,
|
||||
Float,
|
||||
Boolean,
|
||||
]
|
||||
|
@@ -23,9 +23,9 @@ import nodes
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3):
|
||||
class StableCascade_EmptyLatentImage(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="StableCascade_EmptyLatentImage_V3",
|
||||
category="latent/stable_cascade",
|
||||
@@ -48,9 +48,9 @@ class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput({"samples": c_latent}, {"samples": b_latent})
|
||||
|
||||
|
||||
class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3):
|
||||
class StableCascade_StageC_VAEEncode(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="StableCascade_StageC_VAEEncode_V3",
|
||||
category="latent/stable_cascade",
|
||||
@@ -79,9 +79,9 @@ class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput({"samples": c_latent}, {"samples": b_latent})
|
||||
|
||||
|
||||
class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
|
||||
class StableCascade_StageB_Conditioning(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="StableCascade_StageB_Conditioning_V3",
|
||||
category="conditioning/stable_cascade",
|
||||
@@ -105,9 +105,9 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
|
||||
return io.NodeOutput(c)
|
||||
|
||||
|
||||
class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3):
|
||||
class StableCascade_SuperResolutionControlnet(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="StableCascade_SuperResolutionControlnet_V3",
|
||||
category="_for_testing/stable_cascade",
|
||||
@@ -136,8 +136,8 @@ class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3):
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||
StableCascade_EmptyLatentImage_V3,
|
||||
StableCascade_StageB_Conditioning_V3,
|
||||
StableCascade_StageC_VAEEncode_V3,
|
||||
StableCascade_SuperResolutionControlnet_V3,
|
||||
StableCascade_EmptyLatentImage,
|
||||
StableCascade_StageB_Conditioning,
|
||||
StableCascade_StageC_VAEEncode,
|
||||
StableCascade_SuperResolutionControlnet,
|
||||
]
|
||||
|
@@ -9,32 +9,18 @@ import node_helpers
|
||||
import nodes
|
||||
from comfy_api.v3 import io
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
|
||||
class WebcamCapture_V3(io.ComfyNodeV3):
|
||||
class WebcamCapture(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def DEFINE_SCHEMA(cls):
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="WebcamCapture_V3",
|
||||
display_name="Webcam Capture _V3",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Webcam.Input("image"),
|
||||
io.Int.Input(
|
||||
"width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=MAX_RESOLUTION,
|
||||
step=1,
|
||||
),
|
||||
io.Int.Input(
|
||||
"height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=MAX_RESOLUTION,
|
||||
step=1,
|
||||
),
|
||||
io.Int.Input("width", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("height", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
io.Boolean.Input("capture_on_queue", default=True),
|
||||
],
|
||||
outputs=[
|
||||
@@ -103,4 +89,4 @@ class WebcamCapture_V3(io.ComfyNodeV3):
|
||||
return True
|
||||
|
||||
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [WebcamCapture_V3]
|
||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [WebcamCapture]
|
||||
|
15
execution.py
15
execution.py
@@ -28,6 +28,7 @@ from comfy_execution.graph import (
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_api.internal import ComfyNodeInternal
|
||||
from comfy_api.v3 import io, helpers
|
||||
|
||||
|
||||
@@ -54,7 +55,7 @@ class IsChangedCache:
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
has_is_changed = False
|
||||
is_changed_name = None
|
||||
if issubclass(class_def, io.ComfyNodeV3) and helpers.first_real_override(class_def, "fingerprint_inputs", base=io.ComfyNodeV3) is not None:
|
||||
if issubclass(class_def, ComfyNodeInternal) and helpers.first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||
has_is_changed = True
|
||||
is_changed_name = "fingerprint_inputs"
|
||||
elif hasattr(class_def, "IS_CHANGED"):
|
||||
@@ -127,7 +128,7 @@ class CacheSet:
|
||||
return result
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, io.ComfyNodeV3)
|
||||
is_v3 = issubclass(class_def, ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
else:
|
||||
@@ -224,7 +225,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
# V3
|
||||
if isinstance(obj, io.ComfyNodeV3) or (io.is_class(obj) and issubclass(obj, io.ComfyNodeV3)):
|
||||
if isinstance(obj, ComfyNodeInternal) or (io.is_class(obj) and issubclass(obj, ComfyNodeInternal)):
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
if io.is_class(obj):
|
||||
type_obj = obj
|
||||
@@ -411,8 +412,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if issubclass(class_def, io.ComfyNodeV3):
|
||||
lazy_status_present = helpers.first_real_override(class_def, "check_lazy_status", base=io.ComfyNodeV3) is not None
|
||||
if issubclass(class_def, ComfyNodeInternal):
|
||||
lazy_status_present = helpers.first_real_override(class_def, "check_lazy_status") is not None
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
@@ -674,9 +675,9 @@ def validate_inputs(prompt, item, validated):
|
||||
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if issubclass(obj_class, io.ComfyNodeV3):
|
||||
if issubclass(obj_class, ComfyNodeInternal):
|
||||
validate_function_name = "validate_inputs"
|
||||
validate_function = helpers.first_real_override(obj_class, validate_function_name, base=io.ComfyNodeV3)
|
||||
validate_function = helpers.first_real_override(obj_class, validate_function_name)
|
||||
else:
|
||||
validate_function_name = "VALIDATE_INPUTS"
|
||||
validate_function = getattr(obj_class, validate_function_name, None)
|
||||
|
8
nodes.py
8
nodes.py
@@ -2299,13 +2299,19 @@ def init_builtin_extra_nodes():
|
||||
"nodes_tcfg.py",
|
||||
"nodes_v3_test.py",
|
||||
"nodes_v1_test.py",
|
||||
"v3/nodes_ace.py",
|
||||
"v3/nodes_advanced_samplers.py",
|
||||
"v3/nodes_align_your_steps.py",
|
||||
"v3/nodes_audio.py",
|
||||
"v3/nodes_apg.py",
|
||||
"v3/nodes_attention_multiply.py",
|
||||
"v3/nodes_controlnet.py",
|
||||
"v3/nodes_images.py",
|
||||
"v3/nodes_mask.py",
|
||||
"v3/nodes_preview_any.py",
|
||||
"v3/nodes_primitive.py",
|
||||
"v3/nodes_webcam.py",
|
||||
"v3/nodes_stable_cascade.py",
|
||||
"v3/nodes_webcam.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
@@ -24,7 +24,7 @@ lint.select = [
|
||||
"F",
|
||||
]
|
||||
exclude = ["*.ipynb"]
|
||||
line-length = 120
|
||||
line-length = 144
|
||||
lint.pycodestyle.ignore-overlong-task-comments = true
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
|
@@ -29,7 +29,7 @@ import comfy.model_management
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
from comfy_api.v3.io import ComfyNodeV3
|
||||
from comfy_api.internal import ComfyNodeInternal
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -555,7 +555,7 @@ class PromptServer():
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
if issubclass(obj_class, ComfyNodeV3):
|
||||
if issubclass(obj_class, ComfyNodeInternal):
|
||||
return obj_class.GET_NODE_INFO_V1()
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
|
Reference in New Issue
Block a user