1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 23:49:57 +08:00

Merge branch 'master' into v3-definition - async v3 nodes do not currently work, but I will fix that in the next v3 PR

This commit is contained in:
Jedrzej Kosinski
2025-07-18 14:14:02 -07:00
37 changed files with 2437 additions and 480 deletions

View File

@@ -26,6 +26,7 @@ import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
from comfy_api import feature_flags
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
@@ -36,11 +37,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
from protocol import BinaryEventTypes
async def send_socket_catch_exception(function, message):
try:
@@ -179,6 +176,7 @@ class PromptServer():
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
self.sockets_metadata = dict()
self.web_root = (
FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None
@@ -203,20 +201,53 @@ class PromptServer():
else:
sid = uuid.uuid4().hex
# Store WebSocket for backward compatibility
self.sockets[sid] = ws
# Store metadata separately
self.sockets_metadata[sid] = {"feature_flags": {}}
try:
# Send initial state to the new client
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
# Flag to track if we've received the first message
first_message = True
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
logging.warning('ws connection closed with exception %s' % ws.exception())
elif msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
# Check if first message is feature flags
if first_message and data.get("type") == "feature_flags":
# Store client feature flags
client_flags = data.get("data", {})
self.sockets_metadata[sid]["feature_flags"] = client_flags
# Send server feature flags in response
await self.send(
"feature_flags",
feature_flags.get_server_features(),
sid,
)
logging.info(
f"Feature flags negotiated for client {sid}: {client_flags}"
)
first_message = False
except json.JSONDecodeError:
logging.warning(
f"Invalid JSON received from client {sid}: {msg.data}"
)
except Exception as e:
logging.error(f"Error processing WebSocket message: {e}")
finally:
self.sockets.pop(sid, None)
self.sockets_metadata.pop(sid, None)
return ws
@routes.get("/")
@@ -549,6 +580,10 @@ class PromptServer():
}
return web.json_response(system_stats)
@routes.get("/features")
async def get_features(request):
return web.json_response(feature_flags.get_server_features())
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(self.get_queue_info())
@@ -646,7 +681,8 @@ class PromptServer():
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = execution.validate_prompt(prompt)
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
valid = await execution.validate_prompt(prompt_id, prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
@@ -654,7 +690,6 @@ class PromptServer():
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
@@ -769,6 +804,10 @@ class PromptServer():
async def send(self, event, data, sid=None):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
# data is (preview_image, metadata)
preview_image, metadata = data
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
@@ -807,6 +846,43 @@ class PromptServer():
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.Resampling.LANCZOS
image = ImageOps.contain(image, (max_size, max_size), resampling)
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
# Prepare metadata
if metadata is None:
metadata = {}
metadata["image_type"] = mimetype
# Serialize metadata as JSON
import json
metadata_json = json.dumps(metadata).encode('utf-8')
metadata_length = len(metadata_json)
# Prepare image data
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
image_bytes = bytesIO.getvalue()
# Combine metadata and image
combined_data = bytearray()
combined_data.extend(struct.pack(">I", metadata_length))
combined_data.extend(metadata_json)
combined_data.extend(image_bytes)
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)
@@ -848,10 +924,10 @@ class PromptServer():
ssl_ctx = None
scheme = "http"
if args.tls_keyfile and args.tls_certfile:
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
keyfile=args.tls_keyfile)
scheme = "https"
scheme = "https"
if verbose:
logging.info("Starting server\n")