diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 05f43293a..2d83ec012 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import Type, TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton @@ -75,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase): execution: Execution +class ComfyExtension(ABC): + async def on_load(self) -> None: + """ + Called when an extension is loaded. + This should be used to initialize any global resources neeeded by the extension. + """ + + @abstractmethod + async def get_node_list(self) -> list[type[io.ComfyNode]]: + """ + Returns a list of nodes that this extension provides. + """ + class Input: Image = ImageInput Audio = AudioInput @@ -106,4 +120,5 @@ __all__ = [ "Input", "InputImpl", "Types", + "ComfyExtension", ] diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index 9e43ef290..7372b3850 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -1,11 +1,12 @@ import torch import time -from comfy_api.latest import io, ui, resources, _io +from comfy_api.latest import io, ui, resources, _io, ComfyExtension import logging # noqa import folder_paths import comfy.utils import comfy.sd import asyncio +from typing_extensions import override @io.comfytype(io_type="XYZ") class XYZ(io.ComfyTypeIO): @@ -273,7 +274,6 @@ class V3DummyEndInherit(V3DummyEnd): logging.info(f"V3DummyEndInherit: {cls.COOL_VALUE}") return super().execute(xyz) - NODES_LIST: list[type[io.ComfyNode]] = [ V3TestNode, V3LoraLoader, @@ -283,3 +283,11 @@ NODES_LIST: list[type[io.ComfyNode]] = [ V3DummyEnd, V3DummyEndInherit, ] + +class v3TestExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return NODES_LIST + +async def comfy_entrypoint() -> v3TestExtension: + return v3TestExtension() diff --git a/nodes.py b/nodes.py index 51df2b064..ea43f4ea5 100644 --- a/nodes.py +++ b/nodes.py @@ -6,6 +6,7 @@ import os import sys import json import hashlib +import inspect import traceback import math import time @@ -2162,6 +2163,35 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True + # V3 Extension Definition + elif hasattr(module, "comfy_entrypoint"): + entrypoint = getattr(module, "comfy_entrypoint") + if not callable(entrypoint): + logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.") + return False + try: + if inspect.iscoroutinefunction(entrypoint): + extension = await entrypoint() + else: + extension = entrypoint() + if not isinstance(extension, io.ComfyExtension): + logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.") + return False + node_list = await extension.get_node_list() + if not isinstance(node_list, list): + logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.") + return False + for node_cls in node_list: + node_cls: io.ComfyNode + schema = node_cls.GET_SCHEMA() + if schema.node_id not in ignore: + NODE_CLASS_MAPPINGS[schema.node_id] = node_cls + node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) + if schema.display_name is not None: + NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name + except Exception as e: + logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}") + return False # V3 node definition elif getattr(module, "NODES_LIST", None) is not None: for node_cls in module.NODES_LIST: