mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Compare commits
2 Commits
97b8a2c26a
...
5ee381c058
Author | SHA1 | Date | |
---|---|---|---|
|
5ee381c058 | ||
|
4887743a2a |
@@ -5,3 +5,146 @@ from .api_registry import (
|
|||||||
register_versions as register_versions,
|
register_versions as register_versions,
|
||||||
get_all_versions as get_all_versions,
|
get_all_versions as get_all_versions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||||
|
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||||
|
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||||
|
|
||||||
|
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||||
|
"""
|
||||||
|
if base is None:
|
||||||
|
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||||
|
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||||
|
base = cls.GET_BASE_CLASS()
|
||||||
|
base_attr = getattr(base, name, None)
|
||||||
|
if base_attr is None:
|
||||||
|
return None
|
||||||
|
base_func = base_attr.__func__
|
||||||
|
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||||
|
if c is base: # reached the placeholder – we're done
|
||||||
|
break
|
||||||
|
if name in c.__dict__: # first class that *defines* the attr
|
||||||
|
func = getattr(c, name).__func__
|
||||||
|
if func is not base_func: # real override
|
||||||
|
return getattr(cls, name) # bound to *cls*
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _ComfyNodeInternal:
|
||||||
|
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||||
|
|
||||||
|
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V1(cls):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeOutputInternal:
|
||||||
|
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||||
|
|
||||||
|
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def as_pruned_dict(dataclass_obj):
|
||||||
|
'''Return dict of dataclass object with pruned None values.'''
|
||||||
|
return prune_dict(asdict(dataclass_obj))
|
||||||
|
|
||||||
|
def prune_dict(d: dict):
|
||||||
|
return {k: v for k,v in d.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
def is_class(obj):
|
||||||
|
'''
|
||||||
|
Returns True if is a class type.
|
||||||
|
Returns False if is a class instance.
|
||||||
|
'''
|
||||||
|
return isinstance(obj, type)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_class(cls: type) -> type:
|
||||||
|
'''
|
||||||
|
Copy a class and its attributes.
|
||||||
|
'''
|
||||||
|
if cls is None:
|
||||||
|
return None
|
||||||
|
cls_dict = {
|
||||||
|
k: v for k, v in cls.__dict__.items()
|
||||||
|
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||||
|
}
|
||||||
|
# new class
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(cls,),
|
||||||
|
cls_dict
|
||||||
|
)
|
||||||
|
# metadata preservation
|
||||||
|
new_cls.__module__ = cls.__module__
|
||||||
|
new_cls.__doc__ = cls.__doc__
|
||||||
|
return new_cls
|
||||||
|
|
||||||
|
|
||||||
|
class classproperty(object):
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
def __get__(self, obj, owner):
|
||||||
|
return self.f(owner)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: this was ai generated and validated by hand
|
||||||
|
def shallow_clone_class(cls, new_name=None):
|
||||||
|
'''
|
||||||
|
Shallow clone a class while preserving super() functionality.
|
||||||
|
'''
|
||||||
|
new_name = new_name or f"{cls.__name__}Clone"
|
||||||
|
# Include the original class in the bases to maintain proper inheritance
|
||||||
|
new_bases = (cls,) + cls.__bases__
|
||||||
|
return type(new_name, new_bases, dict(cls.__dict__))
|
||||||
|
|
||||||
|
# NOTE: this was ai generated and validated by hand
|
||||||
|
def lock_class(cls):
|
||||||
|
'''
|
||||||
|
Lock a class so that its top-levelattributes cannot be modified.
|
||||||
|
'''
|
||||||
|
# Locked instance __setattr__
|
||||||
|
def locked_instance_setattr(self, name, value):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||||
|
)
|
||||||
|
# Locked metaclass
|
||||||
|
class LockedMeta(type(cls)):
|
||||||
|
def __setattr__(cls_, name, value):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||||
|
)
|
||||||
|
# Rebuild class with locked behavior
|
||||||
|
locked_dict = dict(cls.__dict__)
|
||||||
|
locked_dict['__setattr__'] = locked_instance_setattr
|
||||||
|
|
||||||
|
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def make_locked_method_func(type_obj, func, class_clone):
|
||||||
|
"""
|
||||||
|
Returns a function that, when called with **inputs, will execute:
|
||||||
|
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||||
|
|
||||||
|
Supports both synchronous and asynchronous methods.
|
||||||
|
"""
|
||||||
|
locked_class = lock_class(class_clone)
|
||||||
|
method = getattr(type_obj, func).__func__
|
||||||
|
|
||||||
|
# Check if the original method is async
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
async def wrapped_async_func(**inputs):
|
||||||
|
return await method(locked_class, **inputs)
|
||||||
|
return wrapped_async_func
|
||||||
|
else:
|
||||||
|
def wrapped_func(**inputs):
|
||||||
|
return method(locked_class, **inputs)
|
||||||
|
return wrapped_func
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Type, TYPE_CHECKING
|
from typing import Type, TYPE_CHECKING
|
||||||
from comfy_api.internal import ComfyAPIBase
|
from comfy_api.internal import ComfyAPIBase
|
||||||
from comfy_api.internal.singleton import ProxiedSingleton
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
@@ -7,6 +8,9 @@ from comfy_api.internal.async_to_sync import create_sync_class
|
|||||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
|
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||||
|
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||||
|
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -72,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
|
|||||||
|
|
||||||
execution: Execution
|
execution: Execution
|
||||||
|
|
||||||
|
class ComfyExtension(ABC):
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when an extension is loaded.
|
||||||
|
This should be used to initialize any global resources neeeded by the extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
"""
|
||||||
|
Returns a list of nodes that this extension provides.
|
||||||
|
"""
|
||||||
|
|
||||||
class Input:
|
class Input:
|
||||||
Image = ImageInput
|
Image = ImageInput
|
||||||
Audio = AudioInput
|
Audio = AudioInput
|
||||||
@@ -103,4 +120,5 @@ __all__ = [
|
|||||||
"Input",
|
"Input",
|
||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
|
"ComfyExtension",
|
||||||
]
|
]
|
||||||
|
1618
comfy_api/latest/_io.py
Normal file
1618
comfy_api/latest/_io.py
Normal file
File diff suppressed because it is too large
Load Diff
72
comfy_api/latest/_resources.py
Normal file
72
comfy_api/latest/_resources.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class ResourceKey(ABC):
|
||||||
|
Type = Any
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
class TorchDictFolderFilename(ResourceKey):
|
||||||
|
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||||
|
Type = dict[str, torch.Tensor]
|
||||||
|
def __init__(self, folder_name: str, file_name: str):
|
||||||
|
self.folder_name = folder_name
|
||||||
|
self.file_name = file_name
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.folder_name, self.file_name))
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, TorchDictFolderFilename):
|
||||||
|
return False
|
||||||
|
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.folder_name} -> {self.file_name}"
|
||||||
|
|
||||||
|
class Resources(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ResourcesLocal(Resources):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.local_resources: dict[ResourceKey, Any] = {}
|
||||||
|
|
||||||
|
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||||
|
cached = self.local_resources.get(key, None)
|
||||||
|
if cached is not None:
|
||||||
|
logging.info(f"Using cached resource '{key}'")
|
||||||
|
return cached
|
||||||
|
logging.info(f"Loading resource '{key}'")
|
||||||
|
to_return = None
|
||||||
|
if isinstance(key, TorchDictFolderFilename):
|
||||||
|
if default is ...:
|
||||||
|
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||||
|
else:
|
||||||
|
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||||
|
if full_path is not None:
|
||||||
|
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||||
|
|
||||||
|
if to_return is not None:
|
||||||
|
self.local_resources[key] = to_return
|
||||||
|
return to_return
|
||||||
|
if default is not ...:
|
||||||
|
return default
|
||||||
|
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||||
|
|
||||||
|
|
||||||
|
class _RESOURCES:
|
||||||
|
ResourceKey = ResourceKey
|
||||||
|
TorchDictFolderFilename = TorchDictFolderFilename
|
||||||
|
Resources = Resources
|
||||||
|
ResourcesLocal = ResourcesLocal
|
457
comfy_api/latest/_ui.py
Normal file
457
comfy_api/latest/_ui.py
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
# used for image preview
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||||
|
|
||||||
|
|
||||||
|
class SavedResult(dict):
|
||||||
|
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||||
|
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filename(self) -> str:
|
||||||
|
return self["filename"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subfolder(self) -> str:
|
||||||
|
return self["subfolder"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> FolderType:
|
||||||
|
return FolderType(self["type"])
|
||||||
|
|
||||||
|
|
||||||
|
class SavedImages(_UIOutput):
|
||||||
|
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||||
|
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.results = results
|
||||||
|
self.is_animated = is_animated
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
data = {"images": self.results}
|
||||||
|
if self.is_animated:
|
||||||
|
data["animated"] = (True,)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class SavedAudios(_UIOutput):
|
||||||
|
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||||
|
def __init__(self, results: list[SavedResult]):
|
||||||
|
super().__init__()
|
||||||
|
self.results = results
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return {"audio": self.results}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||||
|
if folder_type == FolderType.input:
|
||||||
|
return folder_paths.get_input_directory()
|
||||||
|
if folder_type == FolderType.output:
|
||||||
|
return folder_paths.get_output_directory()
|
||||||
|
return folder_paths.get_temp_directory()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSaveHelper:
|
||||||
|
"""A helper class with static methods to handle image saving and metadata."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
|
||||||
|
"""Converts a single torch tensor to a PIL Image."""
|
||||||
|
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||||
|
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||||
|
if args.disable_metadata or cls is None or not cls.hidden:
|
||||||
|
return None
|
||||||
|
metadata = PngInfo()
|
||||||
|
if cls.hidden.prompt:
|
||||||
|
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
|
||||||
|
if cls.hidden.extra_pnginfo:
|
||||||
|
for x in cls.hidden.extra_pnginfo:
|
||||||
|
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||||
|
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||||
|
if args.disable_metadata or cls is None or not cls.hidden:
|
||||||
|
return None
|
||||||
|
metadata = PngInfo()
|
||||||
|
if cls.hidden.prompt:
|
||||||
|
metadata.add(
|
||||||
|
b"comf",
|
||||||
|
"prompt".encode("latin-1", "strict")
|
||||||
|
+ b"\0"
|
||||||
|
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
|
||||||
|
after_idat=True,
|
||||||
|
)
|
||||||
|
if cls.hidden.extra_pnginfo:
|
||||||
|
for x in cls.hidden.extra_pnginfo:
|
||||||
|
metadata.add(
|
||||||
|
b"comf",
|
||||||
|
x.encode("latin-1", "strict")
|
||||||
|
+ b"\0"
|
||||||
|
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
|
||||||
|
after_idat=True,
|
||||||
|
)
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||||
|
"""Creates EXIF metadata bytes for WebP images."""
|
||||||
|
exif_data = pil_image.getexif()
|
||||||
|
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||||
|
return exif_data
|
||||||
|
if cls.hidden.prompt is not None:
|
||||||
|
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
|
||||||
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
|
inital_exif_tag = 0x010F # EXIF 0x010f = Make
|
||||||
|
for key, value in cls.hidden.extra_pnginfo.items():
|
||||||
|
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
|
||||||
|
inital_exif_tag -= 1
|
||||||
|
return exif_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_images(
|
||||||
|
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||||
|
) -> list[SavedResult]:
|
||||||
|
"""Saves a batch of images as individual PNG files."""
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
metadata = ImageSaveHelper._create_png_metadata(cls)
|
||||||
|
for batch_number, image_tensor in enumerate(images):
|
||||||
|
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
|
||||||
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
|
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||||
|
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
|
||||||
|
results.append(SavedResult(file, subfolder, folder_type))
|
||||||
|
counter += 1
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||||
|
"""Saves a batch of images and returns a UI object for the node output."""
|
||||||
|
return SavedImages(
|
||||||
|
ImageSaveHelper.save_images(
|
||||||
|
images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
compress_level=compress_level,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_animated_png(
|
||||||
|
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||||
|
) -> SavedResult:
|
||||||
|
"""Saves a batch of images as a single animated PNG."""
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
|
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||||
|
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
|
||||||
|
file = f"{filename}_{counter:05}_.png"
|
||||||
|
save_path = os.path.join(full_output_folder, file)
|
||||||
|
pil_images[0].save(
|
||||||
|
save_path,
|
||||||
|
pnginfo=metadata,
|
||||||
|
compress_level=compress_level,
|
||||||
|
save_all=True,
|
||||||
|
duration=int(1000.0 / fps),
|
||||||
|
append_images=pil_images[1:],
|
||||||
|
)
|
||||||
|
return SavedResult(file, subfolder, folder_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_animated_png_ui(
|
||||||
|
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||||
|
) -> SavedImages:
|
||||||
|
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||||
|
result = ImageSaveHelper.save_animated_png(
|
||||||
|
images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
fps=fps,
|
||||||
|
compress_level=compress_level,
|
||||||
|
)
|
||||||
|
return SavedImages([result], is_animated=len(images) > 1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_animated_webp(
|
||||||
|
images,
|
||||||
|
filename_prefix: str,
|
||||||
|
folder_type: FolderType,
|
||||||
|
cls: Type[ComfyNode] | None,
|
||||||
|
fps: float,
|
||||||
|
lossless: bool,
|
||||||
|
quality: int,
|
||||||
|
method: int,
|
||||||
|
) -> SavedResult:
|
||||||
|
"""Saves a batch of images as a single animated WebP."""
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
|
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||||
|
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
|
||||||
|
file = f"{filename}_{counter:05}_.webp"
|
||||||
|
pil_images[0].save(
|
||||||
|
os.path.join(full_output_folder, file),
|
||||||
|
save_all=True,
|
||||||
|
duration=int(1000.0 / fps),
|
||||||
|
append_images=pil_images[1:],
|
||||||
|
exif=pil_exif,
|
||||||
|
lossless=lossless,
|
||||||
|
quality=quality,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
return SavedResult(file, subfolder, folder_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_animated_webp_ui(
|
||||||
|
images,
|
||||||
|
filename_prefix: str,
|
||||||
|
cls: Type[ComfyNode] | None,
|
||||||
|
fps: float,
|
||||||
|
lossless: bool,
|
||||||
|
quality: int,
|
||||||
|
method: int,
|
||||||
|
) -> SavedImages:
|
||||||
|
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||||
|
result = ImageSaveHelper.save_animated_webp(
|
||||||
|
images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
fps=fps,
|
||||||
|
lossless=lossless,
|
||||||
|
quality=quality,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
return SavedImages([result], is_animated=len(images) > 1)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSaveHelper:
|
||||||
|
"""A helper class with static methods to handle audio saving and metadata."""
|
||||||
|
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_audio(
|
||||||
|
audio: dict,
|
||||||
|
filename_prefix: str,
|
||||||
|
folder_type: FolderType,
|
||||||
|
cls: Type[ComfyNode] | None,
|
||||||
|
format: str = "flac",
|
||||||
|
quality: str = "128k",
|
||||||
|
) -> list[SavedResult]:
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
if not args.disable_metadata and cls is not None:
|
||||||
|
if cls.hidden.prompt is not None:
|
||||||
|
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||||
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
|
for x in cls.hidden.extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||||
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
|
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||||
|
output_path = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
|
# Use original sample rate initially
|
||||||
|
sample_rate = audio["sample_rate"]
|
||||||
|
|
||||||
|
# Handle Opus sample rate requirements
|
||||||
|
if format == "opus":
|
||||||
|
if sample_rate > 48000:
|
||||||
|
sample_rate = 48000
|
||||||
|
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
|
||||||
|
# Find the next highest supported rate
|
||||||
|
for rate in sorted(AudioSaveHelper._OPUS_RATES):
|
||||||
|
if rate > sample_rate:
|
||||||
|
sample_rate = rate
|
||||||
|
break
|
||||||
|
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
|
||||||
|
sample_rate = 48000
|
||||||
|
|
||||||
|
# Resample if necessary
|
||||||
|
if sample_rate != audio["sample_rate"]:
|
||||||
|
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||||
|
|
||||||
|
# Create output with specified format
|
||||||
|
output_buffer = BytesIO()
|
||||||
|
output_container = av.open(output_buffer, mode="w", format=format)
|
||||||
|
|
||||||
|
# Set metadata on the container
|
||||||
|
for key, value in metadata.items():
|
||||||
|
output_container.metadata[key] = value
|
||||||
|
|
||||||
|
# Set up the output stream with appropriate properties
|
||||||
|
if format == "opus":
|
||||||
|
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||||
|
if quality == "64k":
|
||||||
|
out_stream.bit_rate = 64000
|
||||||
|
elif quality == "96k":
|
||||||
|
out_stream.bit_rate = 96000
|
||||||
|
elif quality == "128k":
|
||||||
|
out_stream.bit_rate = 128000
|
||||||
|
elif quality == "192k":
|
||||||
|
out_stream.bit_rate = 192000
|
||||||
|
elif quality == "320k":
|
||||||
|
out_stream.bit_rate = 320000
|
||||||
|
elif format == "mp3":
|
||||||
|
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||||
|
if quality == "V0":
|
||||||
|
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||||
|
out_stream.codec_context.qscale = 1
|
||||||
|
elif quality == "128k":
|
||||||
|
out_stream.bit_rate = 128000
|
||||||
|
elif quality == "320k":
|
||||||
|
out_stream.bit_rate = 320000
|
||||||
|
else: # format == "flac":
|
||||||
|
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||||
|
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||||
|
format="flt",
|
||||||
|
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||||
|
)
|
||||||
|
frame.sample_rate = sample_rate
|
||||||
|
frame.pts = 0
|
||||||
|
output_container.mux(out_stream.encode(frame))
|
||||||
|
|
||||||
|
# Flush encoder
|
||||||
|
output_container.mux(out_stream.encode(None))
|
||||||
|
|
||||||
|
# Close containers
|
||||||
|
output_container.close()
|
||||||
|
|
||||||
|
# Write the output to file
|
||||||
|
output_buffer.seek(0)
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(output_buffer.getbuffer())
|
||||||
|
|
||||||
|
results.append(SavedResult(file, subfolder, folder_type))
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_audio_ui(
|
||||||
|
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||||
|
) -> SavedAudios:
|
||||||
|
"""Save and instantly wrap for UI."""
|
||||||
|
return SavedAudios(
|
||||||
|
AudioSaveHelper.save_audio(
|
||||||
|
audio,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
format=format,
|
||||||
|
quality=quality,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewImage(_UIOutput):
|
||||||
|
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||||
|
self.values = ImageSaveHelper.save_images(
|
||||||
|
image,
|
||||||
|
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||||
|
folder_type=FolderType.temp,
|
||||||
|
cls=cls,
|
||||||
|
compress_level=1,
|
||||||
|
)
|
||||||
|
self.animated = animated
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {
|
||||||
|
"images": self.values,
|
||||||
|
"animated": (self.animated,)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewMask(PreviewImage):
|
||||||
|
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
||||||
|
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
|
super().__init__(preview, animated, cls, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewAudio(_UIOutput):
|
||||||
|
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||||
|
self.values = AudioSaveHelper.save_audio(
|
||||||
|
audio,
|
||||||
|
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||||
|
folder_type=FolderType.temp,
|
||||||
|
cls=cls,
|
||||||
|
format="flac",
|
||||||
|
quality="128k",
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return {"audio": self.values}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewVideo(_UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"images": self.values, "animated": (True,)}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewUI3D(_UIOutput):
|
||||||
|
def __init__(self, model_file, camera_info, **kwargs):
|
||||||
|
self.model_file = model_file
|
||||||
|
self.camera_info = camera_info
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"result": [self.model_file, self.camera_info]}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewText(_UIOutput):
|
||||||
|
def __init__(self, value: str, **kwargs):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
|
class _UI:
|
||||||
|
SavedResult = SavedResult
|
||||||
|
SavedImages = SavedImages
|
||||||
|
SavedAudios = SavedAudios
|
||||||
|
ImageSaveHelper = ImageSaveHelper
|
||||||
|
AudioSaveHelper = AudioSaveHelper
|
||||||
|
PreviewImage = PreviewImage
|
||||||
|
PreviewMask = PreviewMask
|
||||||
|
PreviewAudio = PreviewAudio
|
||||||
|
PreviewVideo = PreviewVideo
|
||||||
|
PreviewUI3D = PreviewUI3D
|
||||||
|
PreviewText = PreviewText
|
@@ -6,6 +6,7 @@ from comfy_api.latest import (
|
|||||||
)
|
)
|
||||||
from typing import Type, TYPE_CHECKING
|
from typing import Type, TYPE_CHECKING
|
||||||
from comfy_api.internal.async_to_sync import create_sync_class
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
|
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||||
@@ -40,4 +41,5 @@ __all__ = [
|
|||||||
"Input",
|
"Input",
|
||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
|
"ComfyExtension",
|
||||||
]
|
]
|
||||||
|
@@ -4,9 +4,12 @@ from typing import Type, Literal
|
|||||||
import nodes
|
import nodes
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||||
|
|
||||||
|
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||||
|
ExecutionBlocker = ExecutionBlocker
|
||||||
|
|
||||||
class DependencyCycleError(Exception):
|
class DependencyCycleError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort):
|
|||||||
del blocked_by[node_id]
|
del blocked_by[node_id]
|
||||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||||
return list(blocked_by.keys())
|
return list(blocked_by.keys())
|
||||||
|
|
||||||
class ExecutionBlocker:
|
|
||||||
"""
|
|
||||||
Return this from a node and any users will be blocked with the given error message.
|
|
||||||
If the message is None, execution will be blocked silently instead.
|
|
||||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
|
||||||
possible, a lazy input will be more efficient and have a better user experience.
|
|
||||||
This functionality is useful in two cases:
|
|
||||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
|
||||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
|
||||||
lazy evaluation to let it conditionally disable itself.)
|
|
||||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
|
||||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
|
||||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
|
||||||
"""
|
|
||||||
def __init__(self, message):
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
|
@@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix):
|
|||||||
|
|
||||||
return new_graph, tuple(new_outputs)
|
return new_graph, tuple(new_outputs)
|
||||||
|
|
||||||
|
class ExecutionBlocker:
|
||||||
|
"""
|
||||||
|
Return this from a node and any users will be blocked with the given error message.
|
||||||
|
If the message is None, execution will be blocked silently instead.
|
||||||
|
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||||
|
possible, a lazy input will be more efficient and have a better user experience.
|
||||||
|
This functionality is useful in two cases:
|
||||||
|
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||||
|
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||||
|
lazy evaluation to let it conditionally disable itself.)
|
||||||
|
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||||
|
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||||
|
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||||
|
"""
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
@@ -149,6 +149,7 @@ class WanFirstLastFrameToVideo:
|
|||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||||
|
|
||||||
|
clip_vision_output = None
|
||||||
if clip_vision_start_image is not None:
|
if clip_vision_start_image is not None:
|
||||||
clip_vision_output = clip_vision_start_image
|
clip_vision_output = clip_vision_start_image
|
||||||
|
|
||||||
|
106
execution.py
106
execution.py
@@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
|||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@@ -56,7 +58,15 @@ class IsChangedCache:
|
|||||||
node = self.dynprompt.get_node(node_id)
|
node = self.dynprompt.get_node(node_id)
|
||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
if not hasattr(class_def, "IS_CHANGED"):
|
has_is_changed = False
|
||||||
|
is_changed_name = None
|
||||||
|
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||||
|
has_is_changed = True
|
||||||
|
is_changed_name = "fingerprint_inputs"
|
||||||
|
elif hasattr(class_def, "IS_CHANGED"):
|
||||||
|
has_is_changed = True
|
||||||
|
is_changed_name = "IS_CHANGED"
|
||||||
|
if not has_is_changed:
|
||||||
self.is_changed[node_id] = False
|
self.is_changed[node_id] = False
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
@@ -65,9 +75,9 @@ class IsChangedCache:
|
|||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -126,9 +136,14 @@ class CacheSet:
|
|||||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||||
|
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||||
|
if is_v3:
|
||||||
|
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||||
|
else:
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
|
hidden_inputs_v3 = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
@@ -153,6 +168,21 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
elif input_category is not None:
|
elif input_category is not None:
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
|
if is_v3:
|
||||||
|
if schema.hidden:
|
||||||
|
if io.Hidden.prompt in schema.hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||||
|
if io.Hidden.dynprompt in schema.hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||||
|
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||||
|
if io.Hidden.unique_id in schema.hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||||
|
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||||
|
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||||
|
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||||
|
else:
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
h = valid_inputs["hidden"]
|
h = valid_inputs["hidden"]
|
||||||
for x in h:
|
for x in h:
|
||||||
@@ -168,7 +198,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||||
if h[x] == "API_KEY_COMFY_ORG":
|
if h[x] == "API_KEY_COMFY_ORG":
|
||||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||||
return input_data_all, missing_keys
|
return input_data_all, missing_keys, hidden_inputs_v3
|
||||||
|
|
||||||
map_node_over_list = None #Don't hook this please
|
map_node_over_list = None #Don't hook this please
|
||||||
|
|
||||||
@@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results):
|
|||||||
raise exc
|
raise exc
|
||||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||||
|
|
||||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
|
|
||||||
@@ -214,6 +244,21 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
if execution_block is None:
|
if execution_block is None:
|
||||||
if pre_execute_cb is not None and index is not None:
|
if pre_execute_cb is not None and index is not None:
|
||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
|
# V3
|
||||||
|
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||||
|
# if is just a class, then assign no resources or state, just create clone
|
||||||
|
if is_class(obj):
|
||||||
|
type_obj = obj
|
||||||
|
obj.VALIDATE_CLASS()
|
||||||
|
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||||
|
# otherwise, use class instance to populate/reuse some fields
|
||||||
|
else:
|
||||||
|
type_obj = type(obj)
|
||||||
|
type_obj.VALIDATE_CLASS()
|
||||||
|
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||||
|
f = make_locked_method_func(type_obj, func, class_clone)
|
||||||
|
# V1
|
||||||
|
else:
|
||||||
f = getattr(obj, func)
|
f = getattr(obj, func)
|
||||||
if inspect.iscoroutinefunction(f):
|
if inspect.iscoroutinefunction(f):
|
||||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||||
@@ -266,8 +311,8 @@ def merge_result_data(results, obj):
|
|||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||||
if has_pending_task:
|
if has_pending_task:
|
||||||
return return_values, {}, False, has_pending_task
|
return return_values, {}, False, has_pending_task
|
||||||
@@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj):
|
|||||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
results.append(result)
|
results.append(result)
|
||||||
subgraph_results.append((None, result))
|
subgraph_results.append((None, result))
|
||||||
|
elif isinstance(r, _NodeOutputInternal):
|
||||||
|
# V3
|
||||||
|
if r.ui is not None:
|
||||||
|
if isinstance(r.ui, dict):
|
||||||
|
uis.append(r.ui)
|
||||||
|
else:
|
||||||
|
uis.append(r.ui.as_dict())
|
||||||
|
if r.expand is not None:
|
||||||
|
has_subgraph = True
|
||||||
|
new_graph = r.expand
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
subgraph_results.append((new_graph, result))
|
||||||
|
elif r.result is not None:
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(result)
|
||||||
|
subgraph_results.append((None, result))
|
||||||
else:
|
else:
|
||||||
if isinstance(r, ExecutionBlocker):
|
if isinstance(r, ExecutionBlocker):
|
||||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
@@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
else:
|
else:
|
||||||
get_progress_state().start_progress(unique_id)
|
get_progress_state().start_progress(unique_id)
|
||||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
@@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
obj = class_def()
|
obj = class_def()
|
||||||
caches.objects.set(unique_id, obj)
|
caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
if hasattr(obj, "check_lazy_status"):
|
if issubclass(class_def, _ComfyNodeInternal):
|
||||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||||
|
else:
|
||||||
|
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||||
|
if lazy_status_present:
|
||||||
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
@@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
def pre_execute_cb(call_index):
|
def pre_execute_cb(call_index):
|
||||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||||
if has_pending_tasks:
|
if has_pending_tasks:
|
||||||
pending_async_nodes[unique_id] = output_data
|
pending_async_nodes[unique_id] = output_data
|
||||||
unblock = execution_list.add_external_block(unique_id)
|
unblock = execution_list.add_external_block(unique_id)
|
||||||
@@ -672,8 +741,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
|
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
validate_has_kwargs = False
|
validate_has_kwargs = False
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
validate_function_name = "validate_inputs"
|
||||||
|
validate_function = first_real_override(obj_class, validate_function_name)
|
||||||
|
else:
|
||||||
|
validate_function_name = "VALIDATE_INPUTS"
|
||||||
|
validate_function = getattr(obj_class, validate_function_name, None)
|
||||||
|
if validate_function is not None:
|
||||||
|
argspec = inspect.getfullargspec(validate_function)
|
||||||
validate_function_inputs = argspec.args
|
validate_function_inputs = argspec.args
|
||||||
validate_has_kwargs = argspec.varkw is not None
|
validate_has_kwargs = argspec.varkw is not None
|
||||||
received_types = {}
|
received_types = {}
|
||||||
@@ -848,7 +923,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||||
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs or validate_has_kwargs:
|
if x in validate_function_inputs or validate_has_kwargs:
|
||||||
@@ -856,8 +931,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
if 'input_types' in validate_function_inputs:
|
if 'input_types' in validate_function_inputs:
|
||||||
input_filtered['input_types'] = [received_types]
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
|
||||||
ret = await resolve_map_node_over_list_results(ret)
|
ret = await resolve_map_node_over_list_results(ret)
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
|
37
nodes.py
37
nodes.py
@@ -6,6 +6,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
@@ -29,6 +30,7 @@ import comfy.controlnet
|
|||||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||||
from comfy_api.version_list import supported_versions
|
from comfy_api.version_list import supported_versions
|
||||||
|
from comfy_api.latest import io, ComfyExtension
|
||||||
|
|
||||||
import comfy.clip_vision
|
import comfy.clip_vision
|
||||||
|
|
||||||
@@ -2152,6 +2154,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
|||||||
if os.path.isdir(web_dir):
|
if os.path.isdir(web_dir):
|
||||||
EXTENSION_WEB_DIRS[module_name] = web_dir
|
EXTENSION_WEB_DIRS[module_name] = web_dir
|
||||||
|
|
||||||
|
# V1 node definition
|
||||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||||
if name not in ignore:
|
if name not in ignore:
|
||||||
@@ -2160,8 +2163,38 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
|||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
return True
|
return True
|
||||||
|
# V3 Extension Definition
|
||||||
|
elif hasattr(module, "comfy_entrypoint"):
|
||||||
|
entrypoint = getattr(module, "comfy_entrypoint")
|
||||||
|
if not callable(entrypoint):
|
||||||
|
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if inspect.iscoroutinefunction(entrypoint):
|
||||||
|
extension = await entrypoint()
|
||||||
else:
|
else:
|
||||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
extension = entrypoint()
|
||||||
|
if not isinstance(extension, ComfyExtension):
|
||||||
|
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
|
||||||
|
return False
|
||||||
|
node_list = await extension.get_node_list()
|
||||||
|
if not isinstance(node_list, list):
|
||||||
|
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
|
||||||
|
return False
|
||||||
|
for node_cls in node_list:
|
||||||
|
node_cls: io.ComfyNode
|
||||||
|
schema = node_cls.GET_SCHEMA()
|
||||||
|
if schema.node_id not in ignore:
|
||||||
|
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||||
|
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||||
|
if schema.display_name is not None:
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(traceback.format_exc())
|
logging.warning(traceback.format_exc())
|
||||||
@@ -2286,7 +2319,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_string.py",
|
"nodes_string.py",
|
||||||
"nodes_camera_trajectory.py",
|
"nodes_camera_trajectory.py",
|
||||||
"nodes_edit_model.py",
|
"nodes_edit_model.py",
|
||||||
"nodes_tcfg.py"
|
"nodes_tcfg.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
@@ -30,6 +30,7 @@ from comfy_api import feature_flags
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
from comfyui_version import __version__
|
from comfyui_version import __version__
|
||||||
from app.frontend_management import FrontendManager
|
from app.frontend_management import FrontendManager
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
@@ -591,6 +592,8 @@ class PromptServer():
|
|||||||
|
|
||||||
def node_info(node_class):
|
def node_info(node_class):
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
|
return obj_class.GET_NODE_INFO_V1()
|
||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||||
|
Reference in New Issue
Block a user