1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 07:26:31 +08:00

Try out adding Type class var to IO_V3 to help with type hints

This commit is contained in:
kosinkadink1@gmail.com
2025-06-10 00:19:17 -07:00
parent 2197b6cbf3
commit 70d2bbfec0
2 changed files with 26 additions and 5 deletions

View File

@@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal from typing import Any, Literal, TYPE_CHECKING, TypeVar
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
# if TYPE_CHECKING:
import torch
class InputBehavior(str, Enum): class InputBehavior(str, Enum):
required = "required" required = "required"
@@ -60,11 +63,14 @@ class IO_V3:
''' '''
Base class for V3 Inputs and Outputs. Base class for V3 Inputs and Outputs.
''' '''
Type = Any
def __init__(self): def __init__(self):
pass pass
def __init_subclass__(cls, io_type: IO | str, **kwargs): def __init_subclass__(cls, io_type: IO | str, Type=Any, **kwargs):
cls.io_type = io_type cls.io_type = io_type
cls.Type = Type
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
class InputV3(IO_V3, io_type=None): class InputV3(IO_V3, io_type=None):
@@ -141,6 +147,7 @@ class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
''' '''
Boolean input. Boolean input.
''' '''
Type = bool
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None, default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, widgetType: str=None): socketless: bool=None, widgetType: str=None):
@@ -159,6 +166,7 @@ class IntegerInput(WidgetInputV3, io_type=IO.INT):
''' '''
Integer input. Integer input.
''' '''
Type = int
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None): display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
@@ -183,6 +191,7 @@ class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
''' '''
Float input. Float input.
''' '''
Type = float
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None): display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
@@ -208,6 +217,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING):
''' '''
String input. String input.
''' '''
Type = str
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None, multiline=False, placeholder: str=None, default: int=None,
socketless: bool=None, widgetType: str=None): socketless: bool=None, widgetType: str=None):
@@ -224,6 +234,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING):
class ComboInput(WidgetInputV3, io_type=IO.COMBO): class ComboInput(WidgetInputV3, io_type=IO.COMBO):
'''Combo input (dropdown).''' '''Combo input (dropdown).'''
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str]=None, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None, default: str=None, control_after_generate: bool=None,
image_upload: bool=None, image_folder: FolderType=None, image_upload: bool=None, image_folder: FolderType=None,
@@ -270,6 +281,7 @@ class ImageInput(InputV3, io_type=IO.IMAGE):
''' '''
Image input. Image input.
''' '''
Type = torch.Tensor
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip) super().__init__(id, display_name, behavior, tooltip)
@@ -277,6 +289,7 @@ class MaskInput(InputV3, io_type=IO.MASK):
''' '''
Mask input. Mask input.
''' '''
Type = torch.Tensor
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip) super().__init__(id, display_name, behavior, tooltip)

View File

@@ -2,12 +2,19 @@ import torch
from comfy_api.v3.io import ( from comfy_api.v3.io import (
ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay, ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay,
IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType, IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType,
IntegerOutput, ImageOutput, MultitypedInput, IntegerOutput, ImageOutput, MultitypedInput, InputV3, OutputV3,
NodeOutput, Hidden NodeOutput, Hidden
) )
import logging import logging
class XYZInput(InputV3, io_type="XYZ"):
Type = tuple[int,str]
class XYZOutput(OutputV3, io_type="XYZ"):
...
class V3TestNode(ComfyNodeV3): class V3TestNode(ComfyNodeV3):
def __init__(self): def __init__(self):
@@ -22,7 +29,8 @@ class V3TestNode(ComfyNodeV3):
category="v3 nodes", category="v3 nodes",
inputs=[ inputs=[
ImageInput("image", display_name="new_image"), ImageInput("image", display_name="new_image"),
CustomInput("xyz", "XYZ", behavior=InputBehavior.optional), XYZInput("xyz", behavior=InputBehavior.optional),
#CustomInput("xyz", "XYZ", behavior=InputBehavior.optional),
MaskInput("mask", behavior=InputBehavior.optional), MaskInput("mask", behavior=InputBehavior.optional),
IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42, IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42,
tooltip="My tooltip 😎", display_mode=NumberDisplay.slider), tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
@@ -55,7 +63,7 @@ class V3TestNode(ComfyNodeV3):
) )
@classmethod @classmethod
def execute(cls, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None): def execute(cls, image: ImageInput.Type, some_int: IntegerInput.Type, combo: ComboInput.Type, xyz: XYZInput.Type=None, mask: MaskInput.Type=None):
if hasattr(cls, "hahajkunless"): if hasattr(cls, "hahajkunless"):
raise Exception("The 'cls' variable leaked instance state between runs!") raise Exception("The 'cls' variable leaked instance state between runs!")
if hasattr(cls, "doohickey"): if hasattr(cls, "doohickey"):