mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Create a ComfyExtension
class for future growth
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Type, TYPE_CHECKING
|
from typing import Type, TYPE_CHECKING
|
||||||
from comfy_api.internal import ComfyAPIBase
|
from comfy_api.internal import ComfyAPIBase
|
||||||
from comfy_api.internal.singleton import ProxiedSingleton
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
@@ -75,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
|
|||||||
|
|
||||||
execution: Execution
|
execution: Execution
|
||||||
|
|
||||||
|
class ComfyExtension(ABC):
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when an extension is loaded.
|
||||||
|
This should be used to initialize any global resources neeeded by the extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
"""
|
||||||
|
Returns a list of nodes that this extension provides.
|
||||||
|
"""
|
||||||
|
|
||||||
class Input:
|
class Input:
|
||||||
Image = ImageInput
|
Image = ImageInput
|
||||||
Audio = AudioInput
|
Audio = AudioInput
|
||||||
@@ -106,4 +120,5 @@ __all__ = [
|
|||||||
"Input",
|
"Input",
|
||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
|
"ComfyExtension",
|
||||||
]
|
]
|
||||||
|
@@ -1,11 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
from comfy_api.latest import io, ui, resources, _io
|
from comfy_api.latest import io, ui, resources, _io, ComfyExtension
|
||||||
import logging # noqa
|
import logging # noqa
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
@io.comfytype(io_type="XYZ")
|
@io.comfytype(io_type="XYZ")
|
||||||
class XYZ(io.ComfyTypeIO):
|
class XYZ(io.ComfyTypeIO):
|
||||||
@@ -273,7 +274,6 @@ class V3DummyEndInherit(V3DummyEnd):
|
|||||||
logging.info(f"V3DummyEndInherit: {cls.COOL_VALUE}")
|
logging.info(f"V3DummyEndInherit: {cls.COOL_VALUE}")
|
||||||
return super().execute(xyz)
|
return super().execute(xyz)
|
||||||
|
|
||||||
|
|
||||||
NODES_LIST: list[type[io.ComfyNode]] = [
|
NODES_LIST: list[type[io.ComfyNode]] = [
|
||||||
V3TestNode,
|
V3TestNode,
|
||||||
V3LoraLoader,
|
V3LoraLoader,
|
||||||
@@ -283,3 +283,11 @@ NODES_LIST: list[type[io.ComfyNode]] = [
|
|||||||
V3DummyEnd,
|
V3DummyEnd,
|
||||||
V3DummyEndInherit,
|
V3DummyEndInherit,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
class v3TestExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return NODES_LIST
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> v3TestExtension:
|
||||||
|
return v3TestExtension()
|
||||||
|
30
nodes.py
30
nodes.py
@@ -6,6 +6,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
@@ -2162,6 +2163,35 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
|||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
return True
|
return True
|
||||||
|
# V3 Extension Definition
|
||||||
|
elif hasattr(module, "comfy_entrypoint"):
|
||||||
|
entrypoint = getattr(module, "comfy_entrypoint")
|
||||||
|
if not callable(entrypoint):
|
||||||
|
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if inspect.iscoroutinefunction(entrypoint):
|
||||||
|
extension = await entrypoint()
|
||||||
|
else:
|
||||||
|
extension = entrypoint()
|
||||||
|
if not isinstance(extension, io.ComfyExtension):
|
||||||
|
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
|
||||||
|
return False
|
||||||
|
node_list = await extension.get_node_list()
|
||||||
|
if not isinstance(node_list, list):
|
||||||
|
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
|
||||||
|
return False
|
||||||
|
for node_cls in node_list:
|
||||||
|
node_cls: io.ComfyNode
|
||||||
|
schema = node_cls.GET_SCHEMA()
|
||||||
|
if schema.node_id not in ignore:
|
||||||
|
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||||
|
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||||
|
if schema.display_name is not None:
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||||
|
return False
|
||||||
# V3 node definition
|
# V3 node definition
|
||||||
elif getattr(module, "NODES_LIST", None) is not None:
|
elif getattr(module, "NODES_LIST", None) is not None:
|
||||||
for node_cls in module.NODES_LIST:
|
for node_cls in module.NODES_LIST:
|
||||||
|
Reference in New Issue
Block a user