mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 06:44:49 +08:00
Compare commits
2 Commits
97b8a2c26a
...
5ee381c058
Author | SHA1 | Date | |
---|---|---|---|
|
5ee381c058 | ||
|
4887743a2a |
@@ -5,3 +5,146 @@ from .api_registry import (
|
||||
register_versions as register_versions,
|
||||
get_all_versions as get_all_versions,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||
|
||||
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||
"""
|
||||
if base is None:
|
||||
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||
base = cls.GET_BASE_CLASS()
|
||||
base_attr = getattr(base, name, None)
|
||||
if base_attr is None:
|
||||
return None
|
||||
base_func = base_attr.__func__
|
||||
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||
if c is base: # reached the placeholder – we're done
|
||||
break
|
||||
if name in c.__dict__: # first class that *defines* the attr
|
||||
func = getattr(c, name).__func__
|
||||
if func is not base_func: # real override
|
||||
return getattr(cls, name) # bound to *cls*
|
||||
return None
|
||||
|
||||
|
||||
class _ComfyNodeInternal:
|
||||
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls):
|
||||
...
|
||||
|
||||
|
||||
class _NodeOutputInternal:
|
||||
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
...
|
||||
|
||||
|
||||
def as_pruned_dict(dataclass_obj):
|
||||
'''Return dict of dataclass object with pruned None values.'''
|
||||
return prune_dict(asdict(dataclass_obj))
|
||||
|
||||
def prune_dict(d: dict):
|
||||
return {k: v for k,v in d.items() if v is not None}
|
||||
|
||||
|
||||
def is_class(obj):
|
||||
'''
|
||||
Returns True if is a class type.
|
||||
Returns False if is a class instance.
|
||||
'''
|
||||
return isinstance(obj, type)
|
||||
|
||||
|
||||
def copy_class(cls: type) -> type:
|
||||
'''
|
||||
Copy a class and its attributes.
|
||||
'''
|
||||
if cls is None:
|
||||
return None
|
||||
cls_dict = {
|
||||
k: v for k, v in cls.__dict__.items()
|
||||
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||
}
|
||||
# new class
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(cls,),
|
||||
cls_dict
|
||||
)
|
||||
# metadata preservation
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
return new_cls
|
||||
|
||||
|
||||
class classproperty(object):
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
def __get__(self, obj, owner):
|
||||
return self.f(owner)
|
||||
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def shallow_clone_class(cls, new_name=None):
|
||||
'''
|
||||
Shallow clone a class while preserving super() functionality.
|
||||
'''
|
||||
new_name = new_name or f"{cls.__name__}Clone"
|
||||
# Include the original class in the bases to maintain proper inheritance
|
||||
new_bases = (cls,) + cls.__bases__
|
||||
return type(new_name, new_bases, dict(cls.__dict__))
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def lock_class(cls):
|
||||
'''
|
||||
Lock a class so that its top-levelattributes cannot be modified.
|
||||
'''
|
||||
# Locked instance __setattr__
|
||||
def locked_instance_setattr(self, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||
)
|
||||
# Locked metaclass
|
||||
class LockedMeta(type(cls)):
|
||||
def __setattr__(cls_, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||
)
|
||||
# Rebuild class with locked behavior
|
||||
locked_dict = dict(cls.__dict__)
|
||||
locked_dict['__setattr__'] = locked_instance_setattr
|
||||
|
||||
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||
|
||||
|
||||
def make_locked_method_func(type_obj, func, class_clone):
|
||||
"""
|
||||
Returns a function that, when called with **inputs, will execute:
|
||||
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||
|
||||
Supports both synchronous and asynchronous methods.
|
||||
"""
|
||||
locked_class = lock_class(class_clone)
|
||||
method = getattr(type_obj, func).__func__
|
||||
|
||||
# Check if the original method is async
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
async def wrapped_async_func(**inputs):
|
||||
return await method(locked_class, **inputs)
|
||||
return wrapped_async_func
|
||||
else:
|
||||
def wrapped_func(**inputs):
|
||||
return method(locked_class, **inputs)
|
||||
return wrapped_func
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
@@ -7,6 +8,9 @@ from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
@@ -72,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
|
||||
execution: Execution
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
"""
|
||||
Returns a list of nodes that this extension provides.
|
||||
"""
|
||||
|
||||
class Input:
|
||||
Image = ImageInput
|
||||
Audio = AudioInput
|
||||
@@ -103,4 +120,5 @@ __all__ = [
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
]
|
||||
|
1618
comfy_api/latest/_io.py
Normal file
1618
comfy_api/latest/_io.py
Normal file
File diff suppressed because it is too large
Load Diff
72
comfy_api/latest/_resources.py
Normal file
72
comfy_api/latest/_resources.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class ResourceKey(ABC):
|
||||
Type = Any
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
class TorchDictFolderFilename(ResourceKey):
|
||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||
Type = dict[str, torch.Tensor]
|
||||
def __init__(self, folder_name: str, file_name: str):
|
||||
self.folder_name = folder_name
|
||||
self.file_name = file_name
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.folder_name, self.file_name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TorchDictFolderFilename):
|
||||
return False
|
||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.folder_name} -> {self.file_name}"
|
||||
|
||||
class Resources(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
pass
|
||||
|
||||
class ResourcesLocal(Resources):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.local_resources: dict[ResourceKey, Any] = {}
|
||||
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
cached = self.local_resources.get(key, None)
|
||||
if cached is not None:
|
||||
logging.info(f"Using cached resource '{key}'")
|
||||
return cached
|
||||
logging.info(f"Loading resource '{key}'")
|
||||
to_return = None
|
||||
if isinstance(key, TorchDictFolderFilename):
|
||||
if default is ...:
|
||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||
else:
|
||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||
if full_path is not None:
|
||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||
|
||||
if to_return is not None:
|
||||
self.local_resources[key] = to_return
|
||||
return to_return
|
||||
if default is not ...:
|
||||
return default
|
||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||
|
||||
|
||||
class _RESOURCES:
|
||||
ResourceKey = ResourceKey
|
||||
TorchDictFolderFilename = TorchDictFolderFilename
|
||||
Resources = Resources
|
||||
ResourcesLocal = ResourcesLocal
|
457
comfy_api/latest/_ui.py
Normal file
457
comfy_api/latest/_ui.py
Normal file
@@ -0,0 +1,457 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self["filename"]
|
||||
|
||||
@property
|
||||
def subfolder(self) -> str:
|
||||
return self["subfolder"]
|
||||
|
||||
@property
|
||||
def type(self) -> FolderType:
|
||||
return FolderType(self["type"])
|
||||
|
||||
|
||||
class SavedImages(_UIOutput):
|
||||
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
self.is_animated = is_animated
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
data = {"images": self.results}
|
||||
if self.is_animated:
|
||||
data["animated"] = (True,)
|
||||
return data
|
||||
|
||||
|
||||
class SavedAudios(_UIOutput):
|
||||
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||
def __init__(self, results: list[SavedResult]):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.results}
|
||||
|
||||
|
||||
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.input:
|
||||
return folder_paths.get_input_directory()
|
||||
if folder_type == FolderType.output:
|
||||
return folder_paths.get_output_directory()
|
||||
return folder_paths.get_temp_directory()
|
||||
|
||||
|
||||
class ImageSaveHelper:
|
||||
"""A helper class with static methods to handle image saving and metadata."""
|
||||
|
||||
@staticmethod
|
||||
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
|
||||
"""Converts a single torch tensor to a PIL Image."""
|
||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
"prompt".encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
x.encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
return exif_data
|
||||
if cls.hidden.prompt is not None:
|
||||
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
inital_exif_tag = 0x010F # EXIF 0x010f = Make
|
||||
for key, value in cls.hidden.extra_pnginfo.items():
|
||||
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
|
||||
inital_exif_tag -= 1
|
||||
return exif_data
|
||||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
results = []
|
||||
metadata = ImageSaveHelper._create_png_metadata(cls)
|
||||
for batch_number, image_tensor in enumerate(images):
|
||||
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
save_path = os.path.join(full_output_folder, file)
|
||||
pil_images[0].save(
|
||||
save_path,
|
||||
pnginfo=metadata,
|
||||
compress_level=compress_level,
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_webp(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated WebP."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[0].save(
|
||||
os.path.join(full_output_folder, file),
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
exif=pil_exif,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedImages:
|
||||
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_webp(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
|
||||
class AudioSaveHelper:
|
||||
"""A helper class with static methods to handle audio saving and metadata."""
|
||||
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
@staticmethod
|
||||
def save_audio(
|
||||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||
)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata and cls is not None:
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(AudioSaveHelper._OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format=format)
|
||||
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
format=format,
|
||||
quality=quality,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
compress_level=1,
|
||||
)
|
||||
self.animated = animated
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"images": self.values,
|
||||
"animated": (self.animated,)
|
||||
}
|
||||
|
||||
|
||||
class PreviewMask(PreviewImage):
|
||||
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
super().__init__(preview, animated, cls, **kwargs)
|
||||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
format="flac",
|
||||
quality="128k",
|
||||
)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.values}
|
||||
|
||||
|
||||
class PreviewVideo(_UIOutput):
|
||||
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||
self.values = values
|
||||
|
||||
def as_dict(self):
|
||||
return {"images": self.values, "animated": (True,)}
|
||||
|
||||
|
||||
class PreviewUI3D(_UIOutput):
|
||||
def __init__(self, model_file, camera_info, **kwargs):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
def __init__(self, value: str, **kwargs):
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return {"text": (self.value,)}
|
||||
|
||||
|
||||
class _UI:
|
||||
SavedResult = SavedResult
|
||||
SavedImages = SavedImages
|
||||
SavedAudios = SavedAudios
|
||||
ImageSaveHelper = ImageSaveHelper
|
||||
AudioSaveHelper = AudioSaveHelper
|
||||
PreviewImage = PreviewImage
|
||||
PreviewMask = PreviewMask
|
||||
PreviewAudio = PreviewAudio
|
||||
PreviewVideo = PreviewVideo
|
||||
PreviewUI3D = PreviewUI3D
|
||||
PreviewText = PreviewText
|
@@ -6,6 +6,7 @@ from comfy_api.latest import (
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
@@ -40,4 +41,5 @@ __all__ = [
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
]
|
||||
|
@@ -4,9 +4,12 @@ from typing import Type, Literal
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||
ExecutionBlocker = ExecutionBlocker
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
@@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort):
|
||||
del blocked_by[node_id]
|
||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||
return list(blocked_by.keys())
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
@@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix):
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
@@ -149,6 +149,7 @@ class WanFirstLastFrameToVideo:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
clip_vision_output = None
|
||||
if clip_vision_start_image is not None:
|
||||
clip_vision_output = clip_vision_start_image
|
||||
|
||||
|
140
execution.py
140
execution.py
@@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -56,7 +58,15 @@ class IsChangedCache:
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if not hasattr(class_def, "IS_CHANGED"):
|
||||
has_is_changed = False
|
||||
is_changed_name = None
|
||||
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||
has_is_changed = True
|
||||
is_changed_name = "fingerprint_inputs"
|
||||
elif hasattr(class_def, "IS_CHANGED"):
|
||||
has_is_changed = True
|
||||
is_changed_name = "IS_CHANGED"
|
||||
if not has_is_changed:
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -65,9 +75,9 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
@@ -126,9 +136,14 @@ class CacheSet:
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
else:
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
missing_keys = {}
|
||||
hidden_inputs_v3 = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
@@ -153,22 +168,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys
|
||||
if is_v3:
|
||||
if schema.hidden:
|
||||
if io.Hidden.prompt in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
if io.Hidden.dynprompt in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
if io.Hidden.unique_id in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys, hidden_inputs_v3
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
@@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results):
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -214,7 +244,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
f = getattr(obj, func)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# V1
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
@@ -266,8 +311,8 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
@@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
elif isinstance(r, _NodeOutputInternal):
|
||||
# V3
|
||||
if r.ui is not None:
|
||||
if isinstance(r.ui, dict):
|
||||
uis.append(r.ui)
|
||||
else:
|
||||
uis.append(r.ui.as_dict())
|
||||
if r.expand is not None:
|
||||
has_subgraph = True
|
||||
new_graph = r.expand
|
||||
result = r.result
|
||||
if r.block_execution is not None:
|
||||
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||
subgraph_results.append((new_graph, result))
|
||||
elif r.result is not None:
|
||||
result = r.result
|
||||
if r.block_execution is not None:
|
||||
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
else:
|
||||
if isinstance(r, ExecutionBlocker):
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
@@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
@@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
@@ -672,8 +741,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
validate_function_name = "validate_inputs"
|
||||
validate_function = first_real_override(obj_class, validate_function_name)
|
||||
else:
|
||||
validate_function_name = "VALIDATE_INPUTS"
|
||||
validate_function = getattr(obj_class, validate_function_name, None)
|
||||
if validate_function is not None:
|
||||
argspec = inspect.getfullargspec(validate_function)
|
||||
validate_function_inputs = argspec.args
|
||||
validate_has_kwargs = argspec.varkw is not None
|
||||
received_types = {}
|
||||
@@ -848,7 +923,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
@@ -856,8 +931,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
|
37
nodes.py
37
nodes.py
@@ -6,6 +6,7 @@ import os
|
||||
import sys
|
||||
import json
|
||||
import hashlib
|
||||
import inspect
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
@@ -29,6 +30,7 @@ import comfy.controlnet
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||
from comfy_api.version_list import supported_versions
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
@@ -2152,6 +2154,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
if os.path.isdir(web_dir):
|
||||
EXTENSION_WEB_DIRS[module_name] = web_dir
|
||||
|
||||
# V1 node definition
|
||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||
if name not in ignore:
|
||||
@@ -2160,8 +2163,38 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
# V3 Extension Definition
|
||||
elif hasattr(module, "comfy_entrypoint"):
|
||||
entrypoint = getattr(module, "comfy_entrypoint")
|
||||
if not callable(entrypoint):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
|
||||
return False
|
||||
try:
|
||||
if inspect.iscoroutinefunction(entrypoint):
|
||||
extension = await entrypoint()
|
||||
else:
|
||||
extension = entrypoint()
|
||||
if not isinstance(extension, ComfyExtension):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
|
||||
return False
|
||||
node_list = await extension.get_node_list()
|
||||
if not isinstance(node_list, list):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
|
||||
return False
|
||||
for node_cls in node_list:
|
||||
node_cls: io.ComfyNode
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
if schema.node_id not in ignore:
|
||||
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||
if schema.display_name is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
@@ -2286,7 +2319,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py"
|
||||
"nodes_tcfg.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
@@ -30,6 +30,7 @@ from comfy_api import feature_flags
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -591,6 +592,8 @@ class PromptServer():
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
return obj_class.GET_NODE_INFO_V1()
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
|
Reference in New Issue
Block a user