1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Merge pull request #9103 from guill/js/extension-api-example

`ComfyExtension` Example (PR to v3-definition)
This commit is contained in:
Jedrzej Kosinski
2025-07-30 19:26:24 -07:00
committed by GitHub
3 changed files with 55 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.singleton import ProxiedSingleton
@@ -75,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
execution: Execution execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
"""
@abstractmethod
async def get_node_list(self) -> list[type[io.ComfyNode]]:
"""
Returns a list of nodes that this extension provides.
"""
class Input: class Input:
Image = ImageInput Image = ImageInput
Audio = AudioInput Audio = AudioInput
@@ -106,4 +120,5 @@ __all__ = [
"Input", "Input",
"InputImpl", "InputImpl",
"Types", "Types",
"ComfyExtension",
] ]

View File

@@ -1,9 +1,10 @@
import torch import torch
import time import time
from comfy_api.latest import io, ui, _io from comfy_api.latest import io, ui, _io, ComfyExtension
import logging # noqa import logging # noqa
import comfy.utils import comfy.utils
import asyncio import asyncio
from typing_extensions import override
@io.comfytype(io_type="XYZ") @io.comfytype(io_type="XYZ")
class XYZ(io.ComfyTypeIO): class XYZ(io.ComfyTypeIO):
@@ -271,7 +272,6 @@ class V3DummyEndInherit(V3DummyEnd):
logging.info(f"V3DummyEndInherit: {cls.COOL_VALUE}") logging.info(f"V3DummyEndInherit: {cls.COOL_VALUE}")
return super().execute(xyz) return super().execute(xyz)
NODES_LIST: list[type[io.ComfyNode]] = [ NODES_LIST: list[type[io.ComfyNode]] = [
V3TestNode, V3TestNode,
# V3LoraLoader, # V3LoraLoader,
@@ -281,3 +281,11 @@ NODES_LIST: list[type[io.ComfyNode]] = [
V3DummyEnd, V3DummyEnd,
V3DummyEndInherit, V3DummyEndInherit,
] ]
class v3TestExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return NODES_LIST
async def comfy_entrypoint() -> v3TestExtension:
return v3TestExtension()

View File

@@ -6,6 +6,7 @@ import os
import sys import sys
import json import json
import hashlib import hashlib
import inspect
import traceback import traceback
import math import math
import time import time
@@ -2162,6 +2163,35 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True return True
# V3 Extension Definition
elif hasattr(module, "comfy_entrypoint"):
entrypoint = getattr(module, "comfy_entrypoint")
if not callable(entrypoint):
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
return False
try:
if inspect.iscoroutinefunction(entrypoint):
extension = await entrypoint()
else:
extension = entrypoint()
if not isinstance(extension, io.ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
return False
for node_cls in node_list:
node_cls: io.ComfyNode
schema = node_cls.GET_SCHEMA()
if schema.node_id not in ignore:
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
except Exception as e:
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
return False
# V3 node definition # V3 node definition
elif getattr(module, "NODES_LIST", None) is not None: elif getattr(module, "NODES_LIST", None) is not None:
for node_cls in module.NODES_LIST: for node_cls in module.NODES_LIST: