1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Create a ComfyExtension class for future growth

This commit is contained in:
Jacob Segal
2025-07-29 16:44:53 -07:00
parent 930f8d9e6d
commit e9a9762ca0
3 changed files with 55 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
@@ -75,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
"""
@abstractmethod
async def get_node_list(self) -> list[type[io.ComfyNode]]:
"""
Returns a list of nodes that this extension provides.
"""
class Input:
Image = ImageInput
Audio = AudioInput
@@ -106,4 +120,5 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"ComfyExtension",
]

View File

@@ -1,11 +1,12 @@
import torch
import time
from comfy_api.latest import io, ui, resources, _io
from comfy_api.latest import io, ui, resources, _io, ComfyExtension
import logging # noqa
import folder_paths
import comfy.utils
import comfy.sd
import asyncio
from typing_extensions import override
@io.comfytype(io_type="XYZ")
class XYZ(io.ComfyTypeIO):
@@ -273,7 +274,6 @@ class V3DummyEndInherit(V3DummyEnd):
logging.info(f"V3DummyEndInherit: {cls.COOL_VALUE}")
return super().execute(xyz)
NODES_LIST: list[type[io.ComfyNode]] = [
V3TestNode,
V3LoraLoader,
@@ -283,3 +283,11 @@ NODES_LIST: list[type[io.ComfyNode]] = [
V3DummyEnd,
V3DummyEndInherit,
]
class v3TestExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return NODES_LIST
async def comfy_entrypoint() -> v3TestExtension:
return v3TestExtension()

View File

@@ -6,6 +6,7 @@ import os
import sys
import json
import hashlib
import inspect
import traceback
import math
import time
@@ -2162,6 +2163,35 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
# V3 Extension Definition
elif hasattr(module, "comfy_entrypoint"):
entrypoint = getattr(module, "comfy_entrypoint")
if not callable(entrypoint):
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
return False
try:
if inspect.iscoroutinefunction(entrypoint):
extension = await entrypoint()
else:
extension = entrypoint()
if not isinstance(extension, io.ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
return False
for node_cls in node_list:
node_cls: io.ComfyNode
schema = node_cls.GET_SCHEMA()
if schema.node_id not in ignore:
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
except Exception as e:
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
return False
# V3 node definition
elif getattr(module, "NODES_LIST", None) is not None:
for node_cls in module.NODES_LIST: