mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 07:26:31 +08:00
Try out adding Type class var to IO_V3 to help with type hints
This commit is contained in:
@@ -1,10 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, TYPE_CHECKING, TypeVar
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
|
||||||
|
# if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class InputBehavior(str, Enum):
|
class InputBehavior(str, Enum):
|
||||||
required = "required"
|
required = "required"
|
||||||
@@ -60,11 +63,14 @@ class IO_V3:
|
|||||||
'''
|
'''
|
||||||
Base class for V3 Inputs and Outputs.
|
Base class for V3 Inputs and Outputs.
|
||||||
'''
|
'''
|
||||||
|
Type = Any
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init_subclass__(cls, io_type: IO | str, **kwargs):
|
def __init_subclass__(cls, io_type: IO | str, Type=Any, **kwargs):
|
||||||
cls.io_type = io_type
|
cls.io_type = io_type
|
||||||
|
cls.Type = Type
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
class InputV3(IO_V3, io_type=None):
|
class InputV3(IO_V3, io_type=None):
|
||||||
@@ -141,6 +147,7 @@ class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
|
|||||||
'''
|
'''
|
||||||
Boolean input.
|
Boolean input.
|
||||||
'''
|
'''
|
||||||
|
Type = bool
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
default: bool=None, label_on: str=None, label_off: str=None,
|
default: bool=None, label_on: str=None, label_off: str=None,
|
||||||
socketless: bool=None, widgetType: str=None):
|
socketless: bool=None, widgetType: str=None):
|
||||||
@@ -159,6 +166,7 @@ class IntegerInput(WidgetInputV3, io_type=IO.INT):
|
|||||||
'''
|
'''
|
||||||
Integer input.
|
Integer input.
|
||||||
'''
|
'''
|
||||||
|
Type = int
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
|
||||||
@@ -183,6 +191,7 @@ class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
|
|||||||
'''
|
'''
|
||||||
Float input.
|
Float input.
|
||||||
'''
|
'''
|
||||||
|
Type = float
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
|
||||||
@@ -208,6 +217,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING):
|
|||||||
'''
|
'''
|
||||||
String input.
|
String input.
|
||||||
'''
|
'''
|
||||||
|
Type = str
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
multiline=False, placeholder: str=None, default: int=None,
|
multiline=False, placeholder: str=None, default: int=None,
|
||||||
socketless: bool=None, widgetType: str=None):
|
socketless: bool=None, widgetType: str=None):
|
||||||
@@ -224,6 +234,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING):
|
|||||||
|
|
||||||
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
|
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
|
||||||
'''Combo input (dropdown).'''
|
'''Combo input (dropdown).'''
|
||||||
|
Type = str
|
||||||
def __init__(self, id: str, options: list[str]=None, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str]=None, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
default: str=None, control_after_generate: bool=None,
|
default: str=None, control_after_generate: bool=None,
|
||||||
image_upload: bool=None, image_folder: FolderType=None,
|
image_upload: bool=None, image_folder: FolderType=None,
|
||||||
@@ -270,6 +281,7 @@ class ImageInput(InputV3, io_type=IO.IMAGE):
|
|||||||
'''
|
'''
|
||||||
Image input.
|
Image input.
|
||||||
'''
|
'''
|
||||||
|
Type = torch.Tensor
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
super().__init__(id, display_name, behavior, tooltip)
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
@@ -277,6 +289,7 @@ class MaskInput(InputV3, io_type=IO.MASK):
|
|||||||
'''
|
'''
|
||||||
Mask input.
|
Mask input.
|
||||||
'''
|
'''
|
||||||
|
Type = torch.Tensor
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
super().__init__(id, display_name, behavior, tooltip)
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
|
@@ -2,12 +2,19 @@ import torch
|
|||||||
from comfy_api.v3.io import (
|
from comfy_api.v3.io import (
|
||||||
ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay,
|
ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay,
|
||||||
IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType,
|
IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType,
|
||||||
IntegerOutput, ImageOutput, MultitypedInput,
|
IntegerOutput, ImageOutput, MultitypedInput, InputV3, OutputV3,
|
||||||
NodeOutput, Hidden
|
NodeOutput, Hidden
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class XYZInput(InputV3, io_type="XYZ"):
|
||||||
|
Type = tuple[int,str]
|
||||||
|
|
||||||
|
class XYZOutput(OutputV3, io_type="XYZ"):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class V3TestNode(ComfyNodeV3):
|
class V3TestNode(ComfyNodeV3):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -22,7 +29,8 @@ class V3TestNode(ComfyNodeV3):
|
|||||||
category="v3 nodes",
|
category="v3 nodes",
|
||||||
inputs=[
|
inputs=[
|
||||||
ImageInput("image", display_name="new_image"),
|
ImageInput("image", display_name="new_image"),
|
||||||
CustomInput("xyz", "XYZ", behavior=InputBehavior.optional),
|
XYZInput("xyz", behavior=InputBehavior.optional),
|
||||||
|
#CustomInput("xyz", "XYZ", behavior=InputBehavior.optional),
|
||||||
MaskInput("mask", behavior=InputBehavior.optional),
|
MaskInput("mask", behavior=InputBehavior.optional),
|
||||||
IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42,
|
IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42,
|
||||||
tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
|
tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
|
||||||
@@ -55,7 +63,7 @@ class V3TestNode(ComfyNodeV3):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None):
|
def execute(cls, image: ImageInput.Type, some_int: IntegerInput.Type, combo: ComboInput.Type, xyz: XYZInput.Type=None, mask: MaskInput.Type=None):
|
||||||
if hasattr(cls, "hahajkunless"):
|
if hasattr(cls, "hahajkunless"):
|
||||||
raise Exception("The 'cls' variable leaked instance state between runs!")
|
raise Exception("The 'cls' variable leaked instance state between runs!")
|
||||||
if hasattr(cls, "doohickey"):
|
if hasattr(cls, "doohickey"):
|
||||||
|
Reference in New Issue
Block a user