mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Compare commits
46 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
6e8cdcd3cb | ||
|
e5c3f4b87f | ||
|
bc6be6c11e | ||
|
94323a26a7 | ||
|
5818f6cf51 | ||
|
0b734de449 | ||
|
5e16f1d24b | ||
|
2fd9c1308a | ||
|
8f0009aad0 | ||
|
41444b5236 | ||
|
772e620e32 | ||
|
07f6eeaa13 | ||
|
22535d0589 | ||
|
898615122f | ||
|
156a28786b | ||
|
f498d855ba | ||
|
b699a15062 | ||
|
9cc90ee3eb | ||
|
9a0a5d32ee | ||
|
d9f90965c8 | ||
|
41886af138 | ||
|
22a1d7ce78 | ||
|
4ac401af2b | ||
|
5fb59c8475 | ||
|
122c9ca1ce | ||
|
3b9a6cf2b1 | ||
|
3748e7ef7a | ||
|
8ebf2d8831 | ||
|
a72d152b0c | ||
|
eb476e6ea9 | ||
|
2d28b0b479 | ||
|
8b275ce5be | ||
|
2a18e98ccf | ||
|
8a5281006f | ||
|
bdeb1c171c | ||
|
9c1ed58ef2 | ||
|
8b90e50979 | ||
|
6ee066a14f | ||
|
dd5b57e3d7 | ||
|
75a818c720 | ||
|
2865f913f7 | ||
|
b49616f951 | ||
|
5e29e7a488 | ||
|
8afb97cd3f | ||
|
69694f40b3 | ||
|
c49025f01b |
@@ -28,7 +28,7 @@
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
@@ -39,6 +39,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- Asynchronous Queue system
|
||||
@@ -140,7 +141,7 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
|
@@ -2,6 +2,7 @@ from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
|
||||
from api_server.services.file_service import FileService
|
||||
from api_server.services.terminal_service import TerminalService
|
||||
import app.logger
|
||||
|
||||
class InternalRoutes:
|
||||
@@ -11,7 +12,8 @@ class InternalRoutes:
|
||||
Check README.md for more information.
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
|
||||
def __init__(self, prompt_server):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
@@ -19,6 +21,8 @@ class InternalRoutes:
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
self.prompt_server = prompt_server
|
||||
self.terminal_service = TerminalService(prompt_server)
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
@@ -34,7 +38,28 @@ class InternalRoutes:
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response(app.logger.get_logs())
|
||||
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
|
||||
|
||||
@self.routes.get('/logs/raw')
|
||||
async def get_logs(request):
|
||||
self.terminal_service.update_size()
|
||||
return web.json_response({
|
||||
"entries": list(app.logger.get_logs()),
|
||||
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
|
||||
})
|
||||
|
||||
@self.routes.patch('/logs/subscribe')
|
||||
async def subscribe_logs(request):
|
||||
json_data = await request.json()
|
||||
client_id = json_data["clientId"]
|
||||
enabled = json_data["enabled"]
|
||||
if enabled:
|
||||
self.terminal_service.subscribe(client_id)
|
||||
else:
|
||||
self.terminal_service.unsubscribe(client_id)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@self.routes.get('/folder_paths')
|
||||
async def get_folder_paths(request):
|
||||
|
60
api_server/services/terminal_service.py
Normal file
60
api_server/services/terminal_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from app.logger import on_flush
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class TerminalService:
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
self.cols = None
|
||||
self.rows = None
|
||||
self.subscriptions = set()
|
||||
on_flush(self.send_messages)
|
||||
|
||||
def get_terminal_size(self):
|
||||
try:
|
||||
size = os.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
try:
|
||||
size = shutil.get_terminal_size()
|
||||
return (size.columns, size.lines)
|
||||
except OSError:
|
||||
return (80, 24) # fallback to 80x24
|
||||
|
||||
def update_size(self):
|
||||
columns, lines = self.get_terminal_size()
|
||||
changed = False
|
||||
|
||||
if columns != self.cols:
|
||||
self.cols = columns
|
||||
changed = True
|
||||
|
||||
if lines != self.rows:
|
||||
self.rows = lines
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
return {"cols": self.cols, "rows": self.rows}
|
||||
|
||||
return None
|
||||
|
||||
def subscribe(self, client_id):
|
||||
self.subscriptions.add(client_id)
|
||||
|
||||
def unsubscribe(self, client_id):
|
||||
self.subscriptions.discard(client_id)
|
||||
|
||||
def send_messages(self, entries):
|
||||
if not len(entries) or not len(self.subscriptions):
|
||||
return
|
||||
|
||||
new_size = self.update_size()
|
||||
|
||||
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
|
||||
if client_id not in self.server.sockets:
|
||||
# Automatically unsub if the socket has disconnected
|
||||
self.unsubscribe(client_id)
|
||||
continue
|
||||
|
||||
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
|
@@ -1,20 +1,69 @@
|
||||
import logging
|
||||
from logging.handlers import MemoryHandler
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
logs = None
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
stdout_interceptor = None
|
||||
stderr_interceptor = None
|
||||
|
||||
|
||||
class LogInterceptor(io.TextIOWrapper):
|
||||
def __init__(self, stream, *args, **kwargs):
|
||||
buffer = stream.buffer
|
||||
encoding = stream.encoding
|
||||
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
|
||||
self._lock = threading.Lock()
|
||||
self._flush_callbacks = []
|
||||
self._logs_since_flush = []
|
||||
|
||||
def write(self, data):
|
||||
entry = {"t": datetime.now().isoformat(), "m": data}
|
||||
with self._lock:
|
||||
self._logs_since_flush.append(entry)
|
||||
|
||||
# Simple handling for cr to overwrite the last output if it isnt a full line
|
||||
# else logs just get full of progress messages
|
||||
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
|
||||
logs.pop()
|
||||
logs.append(entry)
|
||||
super().write(data)
|
||||
|
||||
def flush(self):
|
||||
super().flush()
|
||||
for cb in self._flush_callbacks:
|
||||
cb(self._logs_since_flush)
|
||||
self._logs_since_flush = []
|
||||
|
||||
def on_flush(self, callback):
|
||||
self._flush_callbacks.append(callback)
|
||||
|
||||
|
||||
def get_logs():
|
||||
return "\n".join([formatter.format(x) for x in logs])
|
||||
return logs
|
||||
|
||||
|
||||
def on_flush(callback):
|
||||
if stdout_interceptor is not None:
|
||||
stdout_interceptor.on_flush(callback)
|
||||
if stderr_interceptor is not None:
|
||||
stderr_interceptor.on_flush(callback)
|
||||
|
||||
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
|
||||
# Override output streams and log to buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
|
||||
global stdout_interceptor
|
||||
global stderr_interceptor
|
||||
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
|
||||
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
|
||||
|
||||
# Setup default global logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(log_level)
|
||||
@@ -22,10 +71,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
# Create a memory handler with a deque as its buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
|
||||
memory_handler.buffer = logs
|
||||
memory_handler.setFormatter(formatter)
|
||||
logger.addHandler(memory_handler)
|
||||
|
@@ -1,18 +1,35 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
import folder_paths
|
||||
from .app_settings import AppSettings
|
||||
from typing import TypedDict
|
||||
|
||||
default_user = "default"
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
}
|
||||
|
||||
|
||||
class UserManager():
|
||||
def __init__(self):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
@@ -154,6 +171,7 @@ class UserManager():
|
||||
|
||||
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
|
||||
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
|
||||
# Use different patterns based on whether we're recursing or not
|
||||
if recurse:
|
||||
@@ -161,26 +179,21 @@ class UserManager():
|
||||
else:
|
||||
pattern = os.path.join(glob.escape(path), '*')
|
||||
|
||||
results = glob.glob(pattern, recursive=recurse)
|
||||
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
|
||||
if full_info:
|
||||
return get_file_info(full_path, path)
|
||||
|
||||
if full_info:
|
||||
results = [
|
||||
{
|
||||
'path': os.path.relpath(x, path).replace(os.sep, '/'),
|
||||
'size': os.path.getsize(x),
|
||||
'modified': os.path.getmtime(x)
|
||||
} for x in results if os.path.isfile(x)
|
||||
]
|
||||
else:
|
||||
results = [
|
||||
os.path.relpath(x, path).replace(os.sep, '/')
|
||||
for x in results
|
||||
if os.path.isfile(x)
|
||||
]
|
||||
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
|
||||
if split_path:
|
||||
return [rel_path] + rel_path.split('/')
|
||||
|
||||
split_path = request.rel_url.query.get('split', '').lower() == "true"
|
||||
if split_path and not full_info:
|
||||
results = [[x] + x.split('/') for x in results]
|
||||
return rel_path
|
||||
|
||||
results = [
|
||||
process_full_path(full_path)
|
||||
for full_path in glob.glob(pattern, recursive=recurse)
|
||||
if os.path.isfile(full_path)
|
||||
]
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
@@ -208,20 +221,51 @@ class UserManager():
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
"""
|
||||
Upload or update a user data file.
|
||||
|
||||
This endpoint handles file uploads to a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Path Parameters:
|
||||
- file: The target file path (URL encoded if necessary).
|
||||
|
||||
Returns:
|
||||
- 400: If 'file' parameter is missing.
|
||||
- 403: If the requested path is not allowed.
|
||||
- 409: If overwrite=false and the file already exists.
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
|
||||
The request body should contain the raw file content to be written.
|
||||
"""
|
||||
path = get_user_data_path(request)
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409)
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
|
||||
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(path, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(path, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
@routes.delete("/userdata/{file}")
|
||||
@@ -236,6 +280,30 @@ class UserManager():
|
||||
|
||||
@routes.post("/userdata/{file}/move/{dest}")
|
||||
async def move_userdata(request):
|
||||
"""
|
||||
Move or rename a user data file.
|
||||
|
||||
This endpoint handles moving or renaming files within a user's data directory, with options for
|
||||
controlling overwrite behavior and response format.
|
||||
|
||||
Path Parameters:
|
||||
- file: The source file path (URL encoded if necessary)
|
||||
- dest: The destination file path (URL encoded if necessary)
|
||||
|
||||
Query Parameters:
|
||||
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
|
||||
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
|
||||
If "false", returns only the relative file path.
|
||||
|
||||
Returns:
|
||||
- 400: If either 'file' or 'dest' parameter is missing
|
||||
- 403: If either requested path is not allowed
|
||||
- 404: If the source file does not exist
|
||||
- 409: If overwrite=false and the destination file already exists
|
||||
- 200: JSON response with either:
|
||||
- Full file information (if full_info=true)
|
||||
- Relative file path (if full_info=false)
|
||||
"""
|
||||
source = get_user_data_path(request, check_exists=True)
|
||||
if not isinstance(source, str):
|
||||
return source
|
||||
@@ -244,12 +312,19 @@ class UserManager():
|
||||
if not isinstance(source, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query["overwrite"] != "false"
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409)
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
full_info = request.query.get('full_info', 'false').lower() == "true"
|
||||
|
||||
print(f"moving '{source}' -> '{dest}'")
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
logging.info(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
resp = get_file_info(dest, user_path)
|
||||
else:
|
||||
resp = os.path.relpath(dest, user_path)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
|
||||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
if model_type == "siglip_vision_model":
|
||||
self.class_embedding = None
|
||||
patch_bias = True
|
||||
else:
|
||||
num_patches = num_patches + 1
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
patch_bias = False
|
||||
|
||||
self.patch_embedding = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
bias=patch_bias,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
if self.class_embedding is not None:
|
||||
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
|
||||
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.output_layernorm = False
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
|
||||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
if self.output_layernorm:
|
||||
x = self.post_layernorm(x)
|
||||
pooled_output = x
|
||||
else:
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
if "projection_dim" in config_dict:
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
|
@@ -16,9 +16,9 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
@@ -35,6 +35,8 @@ class ClipVisionModel():
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -51,7 +53,7 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
|
13
comfy/clip_vision_siglip_384.json
Normal file
13
comfy/clip_vision_siglip_384.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 384,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
@@ -190,7 +190,21 @@ class Mochi(LatentFormat):
|
||||
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
|
||||
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
self.latent_rgb_factors = None #TODO
|
||||
self.latent_rgb_factors =[
|
||||
[-0.0069, -0.0045, 0.0018],
|
||||
[ 0.0154, -0.0692, -0.0274],
|
||||
[ 0.0333, 0.0019, 0.0206],
|
||||
[-0.1390, 0.0628, 0.1678],
|
||||
[-0.0725, 0.0134, -0.1898],
|
||||
[ 0.0074, -0.0270, -0.0209],
|
||||
[-0.0176, -0.0277, -0.0221],
|
||||
[ 0.5294, 0.5204, 0.3852],
|
||||
[-0.0326, -0.0446, -0.0143],
|
||||
[-0.0659, 0.0153, -0.0153],
|
||||
[ 0.0185, -0.0217, 0.0014],
|
||||
[-0.0396, -0.0495, -0.0281]
|
||||
]
|
||||
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
|
||||
self.taesd_decoder_name = None #TODO
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -202,3 +216,7 @@ class Mochi(LatentFormat):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
|
||||
|
@@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
@@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
# Iterate over the transformer layers
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
@@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
|
||||
mask=None,
|
||||
return_info=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
**kwargs):
|
||||
return self._forward(
|
||||
x,
|
||||
|
@@ -437,7 +437,8 @@ class MMDiT(nn.Module):
|
||||
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -458,15 +459,36 @@ class MMDiT(nn.Module):
|
||||
|
||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
if len(self.double_layers) > 0:
|
||||
for layer in self.double_layers:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.double_layers):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
cx = torch.cat([c, x], dim=1)
|
||||
for layer in self.single_layers:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
for i, layer in enumerate(self.single_layers):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
@@ -29,6 +30,7 @@ class FluxParams:
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
@@ -43,8 +45,9 @@ class Flux(nn.Module):
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -96,7 +99,9 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
@@ -114,8 +119,19 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -127,7 +143,16 @@ class Flux(nn.Module):
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -141,9 +166,9 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
@@ -151,10 +176,10 @@ class Flux(nn.Module):
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
25
comfy/ldm/flux/redux.py
Normal file
25
comfy/ldm/flux/redux.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
class ReduxImageEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
redux_dim: int = 1152,
|
||||
txt_in_features: int = 4096,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.redux_dim = redux_dim
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
|
||||
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
|
||||
|
||||
def forward(self, sigclip_embeds) -> torch.Tensor:
|
||||
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
return projected_x
|
@@ -494,8 +494,9 @@ class AsymmDiTJoint(nn.Module):
|
||||
packed_indices: Dict[str, torch.Tensor] = None,
|
||||
rope_cos: torch.Tensor = None,
|
||||
rope_sin: torch.Tensor = None,
|
||||
control=None, **kwargs
|
||||
control=None, transformer_options={}, **kwargs
|
||||
):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
y_feat = context
|
||||
y_mask = attention_mask
|
||||
sigma = timestep
|
||||
@@ -515,15 +516,32 @@ class AsymmDiTJoint(nn.Module):
|
||||
)
|
||||
del y_mask
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
x, y_feat = block(
|
||||
x,
|
||||
c,
|
||||
y_feat,
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
) # (B, M, D), (B, L, D)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(
|
||||
args["img"],
|
||||
args["vec"],
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
x, y_feat = block(
|
||||
x,
|
||||
c,
|
||||
y_feat,
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||
|
502
comfy/ldm/lightricks/model.py
Normal file
502
comfy/ldm/lightricks/model.py
Normal file
@@ -0,0 +1,502 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
# else:
|
||||
# self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
"""
|
||||
For PixArt-Alpha.
|
||||
|
||||
Reference:
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# No modulation happening here.
|
||||
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = 0.0
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
|
||||
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
x = block(
|
||||
x,
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
output_width=orig_shape[4],
|
||||
output_num_frames=orig_shape[2],
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
105
comfy/ldm/lightricks/symmetric_patchifier.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w)
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
def patchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self._patch_size[0],
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: Tensor,
|
||||
output_height: int,
|
||||
output_width: int,
|
||||
output_num_frames: int,
|
||||
out_channels: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
output_height = output_height // self._patch_size[1]
|
||||
output_width = output_width // self._patch_size[2]
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b (f h w) (c p q) -> b c f (h p) (w q) ",
|
||||
f=output_num_frames,
|
||||
h=output_height,
|
||||
w=output_width,
|
||||
p=self._patch_size[1],
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
64
comfy/ldm/lightricks/vae/causal_conv3d.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: int = 3,
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
dilation = (dilation, 1, 1)
|
||||
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
padding = (0, height_pad, width_pad)
|
||||
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if causal:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, self.time_kernel_size - 1, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x), dim=2)
|
||||
else:
|
||||
first_frame_pad = x[:, :, :1, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.conv.weight
|
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
698
comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Normal file
@@ -0,0 +1,698 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Any, Mapping, Optional, Tuple, Union, List
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self.blocks_desc = blocks
|
||||
|
||||
in_channels = in_channels * patch_size**2
|
||||
output_channel = base_channels
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in blocks:
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
|
||||
self.down_blocks.append(block)
|
||||
|
||||
# out
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == "per_channel":
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
num_dims = sample.dim()
|
||||
|
||||
if num_dims == 4:
|
||||
# For shape (B, C, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
elif num_dims == 5:
|
||||
# For shape (B, C, F, H, W)
|
||||
repeated_last_channel = last_channel.repeat(
|
||||
1, sample.shape[1] - 2, 1, 1, 1
|
||||
)
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
|
||||
Args:
|
||||
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
||||
The number of dimensions to use in convolutions.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
||||
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
||||
base_channels (`int`, *optional*, defaults to 128):
|
||||
The number of output channels for the first convolutional layer.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
patch_size (`int`, *optional*, defaults to 1):
|
||||
The patch size to use. Should be a power of 2.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
causal (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use causal convolutions or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.layers_per_block = layers_per_block
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.blocks_desc = blocks
|
||||
|
||||
# Compute output channel to be product of all channel-multiplier blocks
|
||||
output_channel = base_channels
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
block_params = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
output_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
for block_name, block_params in list(reversed(blocks)):
|
||||
input_channel = output_channel
|
||||
if isinstance(block_params, int):
|
||||
block_params = {"num_layers": block_params}
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
|
||||
self.up_blocks.append(block)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.conv_norm_out = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
# assert target_shape is not None, "target_shape must be provided"
|
||||
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
in_channels, height, width)`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: str = "group_norm",
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
)
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
||||
) -> torch.FloatTensor:
|
||||
for resnet in self.res_blocks:
|
||||
hidden_states = resnet(hidden_states, causal=causal)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, stride, residual=False):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = math.prod(stride) * in_channels
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
)
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, "b d h w c -> b c d h w")
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
r"""
|
||||
A Resnet block.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to be `None`):
|
||||
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
||||
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
||||
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
||||
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
norm_layer: str = "group_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm1 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm1 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
self.conv1 = make_conv_nd(
|
||||
dims,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||
)
|
||||
elif norm_layer == "pixel_norm":
|
||||
self.norm2 = PixelNorm()
|
||||
elif norm_layer == "layer_norm":
|
||||
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
|
||||
self.conv2 = make_conv_nd(
|
||||
dims,
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
self.conv_shortcut = (
|
||||
make_linear_nd(
|
||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||
)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.norm3 = (
|
||||
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||
|
||||
input_tensor = self.norm3(input_tensor)
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size_hw, patch_size_t=1):
|
||||
if patch_size_hw == 1 and patch_size_t == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
||||
)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
||||
p=patch_size_t,
|
||||
q=patch_size_hw,
|
||||
r=patch_size_hw,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def encode(self, x):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x))
|
||||
|
83
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
83
comfy/ldm/lightricks/vae/conv_nd_factory.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .dual_conv3d import DualConv3d
|
||||
from .causal_conv3d import CausalConv3d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def make_conv_nd(
|
||||
dims: Union[int, Tuple[int, int]],
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def make_linear_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
bias=True,
|
||||
):
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
elif dims == 3 or dims == (2, 1):
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
195
comfy/ldm/lightricks/vae/dual_conv3d.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class DualConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if kernel_size == (1, 1, 1):
|
||||
raise ValueError(
|
||||
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
||||
)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
# Set parameters for convolutions
|
||||
self.groups = groups
|
||||
self.bias = bias
|
||||
|
||||
# Define the size of the channels after the first convolution
|
||||
intermediate_channels = (
|
||||
out_channels if in_channels < out_channels else in_channels
|
||||
)
|
||||
|
||||
# Define parameters for the first convolution
|
||||
self.weight1 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
intermediate_channels,
|
||||
in_channels // groups,
|
||||
1,
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
)
|
||||
)
|
||||
self.stride1 = (1, stride[1], stride[2])
|
||||
self.padding1 = (0, padding[1], padding[2])
|
||||
self.dilation1 = (1, dilation[1], dilation[2])
|
||||
if bias:
|
||||
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
||||
else:
|
||||
self.register_parameter("bias1", None)
|
||||
|
||||
# Define parameters for the second convolution
|
||||
self.weight2 = nn.Parameter(
|
||||
torch.Tensor(
|
||||
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
||||
)
|
||||
)
|
||||
self.stride2 = (stride[0], 1, 1)
|
||||
self.padding2 = (padding[0], 0, 0)
|
||||
self.dilation2 = (dilation[0], 1, 1)
|
||||
if bias:
|
||||
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter("bias2", None)
|
||||
|
||||
# Initialize weights and biases
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
|
||||
if self.bias:
|
||||
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
||||
bound1 = 1 / math.sqrt(fan_in1)
|
||||
nn.init.uniform_(self.bias1, -bound1, bound1)
|
||||
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
||||
bound2 = 1 / math.sqrt(fan_in2)
|
||||
nn.init.uniform_(self.bias2, -bound2, bound2)
|
||||
|
||||
def forward(self, x, use_conv3d=False, skip_time_conv=False):
|
||||
if use_conv3d:
|
||||
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
||||
else:
|
||||
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
||||
|
||||
def forward_with_3d(self, x, skip_time_conv):
|
||||
# First convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight1,
|
||||
self.bias1,
|
||||
self.stride1,
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
return x
|
||||
|
||||
# Second convolution
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight2,
|
||||
self.bias2,
|
||||
self.stride2,
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_2d(self, x, skip_time_conv):
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# First 2D convolution
|
||||
x = rearrange(x, "b c d h w -> (b d) c h w")
|
||||
# Squeeze the depth dimension out of weight1 since it's 1
|
||||
weight1 = self.weight1.squeeze(2)
|
||||
# Select stride, padding, and dilation for the 2D convolution
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
if skip_time_conv:
|
||||
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
return x
|
||||
|
||||
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
||||
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
||||
|
||||
# Reshape weight2 to match the expected dimensions for conv1d
|
||||
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
||||
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.weight2
|
||||
|
||||
|
||||
def test_dual_conv3d_consistency():
|
||||
# Initialize parameters
|
||||
in_channels = 3
|
||||
out_channels = 5
|
||||
kernel_size = (3, 3, 3)
|
||||
stride = (2, 2, 2)
|
||||
padding = (1, 1, 1)
|
||||
|
||||
# Create an instance of the DualConv3d class
|
||||
dual_conv3d = DualConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# Example input tensor
|
||||
test_input = torch.randn(1, 3, 10, 10, 10)
|
||||
|
||||
# Perform forward passes with both 3D and 2D settings
|
||||
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
|
||||
output_2d = dual_conv3d(test_input, use_conv3d=False)
|
||||
|
||||
# Assert that the outputs from both methods are sufficiently close
|
||||
assert torch.allclose(
|
||||
output_conv3d, output_2d, atol=1e-6
|
||||
), "Outputs are not consistent between 3D and 2D convolutions."
|
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
12
comfy/ldm/lightricks/vae/pixel_norm.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self, dim=1, eps=1e-8):
|
||||
super(PixelNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
|
@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
if len(mask.shape) == 2:
|
||||
s1 += mask[i:end]
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
if mask.shape[1] == 1:
|
||||
s1 += mask
|
||||
else:
|
||||
s1 += mask[:, i:end]
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
@@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
pad = 8 - q.shape[1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[:, :, :mask.shape[-1]] = mask
|
||||
mask = mask_out[:, :, :mask.shape[-1]]
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
|
@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
|
||||
def get_mask_chunk(chunk_idx: int) -> Tensor:
|
||||
if mask is None:
|
||||
return None
|
||||
if mask.shape[1] == 1:
|
||||
return mask
|
||||
chunk = min(query_chunk_size, q_tokens)
|
||||
return mask[:,chunk_idx:chunk_idx + chunk]
|
||||
|
||||
|
@@ -49,6 +49,15 @@ def load_lora(lora, to_load):
|
||||
dora_scale = lora[dora_scale_name]
|
||||
loaded_keys.add(dora_scale_name)
|
||||
|
||||
reshape_name = "{}.reshape_weight".format(x)
|
||||
reshape = None
|
||||
if reshape_name in lora.keys():
|
||||
try:
|
||||
reshape = lora[reshape_name].tolist()
|
||||
loaded_keys.add(reshape_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||
@@ -82,7 +91,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@@ -193,6 +202,12 @@ def load_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
logging.warning("lora key not loaded: {}".format(x))
|
||||
@@ -282,11 +297,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@@ -440,10 +458,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
reshape = v[5]
|
||||
|
||||
if reshape is not None:
|
||||
weight = pad_tensor_to_shape(weight, reshape)
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
|
17
comfy/lora_convert.py
Normal file
17
comfy/lora_convert.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
return sd_out
|
||||
|
||||
|
||||
def convert_lora(sd):
|
||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||
return convert_lora_bfl_control(sd)
|
||||
return sd
|
@@ -30,6 +30,7 @@ import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
@@ -153,8 +154,7 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
def concat_cond(self, **kwargs):
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
@@ -193,7 +193,14 @@ class BaseModel(torch.nn.Module):
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
return data
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
if concat_cond is not None:
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
@@ -523,9 +530,7 @@ class SD_X4Upscaler(BaseModel):
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
@@ -537,18 +542,15 @@ class IP2P:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
@@ -709,6 +711,38 @@ class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
|
||||
out_channels = self.model_config.unet_config["out_channels"]
|
||||
|
||||
if num_channels <= out_channels:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
image = self.process_latent_in(image)
|
||||
if num_channels <= out_channels * 2:
|
||||
return image
|
||||
|
||||
#inpaint model
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.ones_like(noise)[:, :1]
|
||||
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
print(mask.shape)
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
||||
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
@@ -734,3 +768,23 @@ class GenmoMochi(BaseModel):
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class LTXV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
guiding_latent = kwargs.get("guiding_latent", None)
|
||||
if guiding_latent is not None:
|
||||
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
return out
|
||||
|
@@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
@@ -177,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
return dit_config
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
@@ -321,8 +331,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
|
||||
if scaled_fp8_weight is not None:
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
|
@@ -373,14 +373,23 @@ class ModelPatcher:
|
||||
lowvram_counter = 0
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||
loading.append((comfy.model_management.module_size(m), n, m))
|
||||
params = []
|
||||
skip = False
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
for name, param in m.named_parameters(recurse=True):
|
||||
if name not in params:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
module_mem = x[0]
|
||||
|
||||
lowvram_weight = False
|
||||
@@ -416,22 +425,22 @@ class ModelPatcher:
|
||||
if m.comfy_cast_weights:
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m))
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
params = x[3]
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
|
||||
self.patch_weight_to_device(weight_key, device_to=device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to=device_to)
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
|
||||
|
@@ -2,6 +2,25 @@ import torch
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
@@ -48,7 +67,7 @@ class CONST:
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||
timesteps = sampling_settings.get("timesteps", 1000)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
if zsnr is None:
|
||||
zsnr = sampling_settings.get("zsnr", False)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
@@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
if zsnr:
|
||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
|
@@ -1,14 +1,10 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
|
56
comfy/sd.py
56
comfy/sd.py
@@ -8,6 +8,7 @@ from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@@ -27,12 +28,16 @@ import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.long_clipl
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.lora_convert
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@@ -40,6 +45,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
lora = comfy.lora_convert.convert_lora(lora)
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
@@ -245,7 +251,7 @@ class VAE:
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
|
||||
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
|
||||
if "blocks.2.blocks.3.stack.5.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
|
||||
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
|
||||
@@ -257,6 +263,14 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -356,15 +370,33 @@ class VAE:
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in)
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
return output.movedim(1,-1)
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
dims = samples.ndim - 2
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
@@ -404,6 +436,12 @@ class VAE:
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
def spacial_compression_decode(self):
|
||||
try:
|
||||
return self.upscale_ratio[-1]
|
||||
except:
|
||||
return self.upscale_ratio
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
@@ -417,6 +455,8 @@ def load_style_model(ckpt_path):
|
||||
keys = model_data.keys()
|
||||
if "style_embedding" in keys:
|
||||
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||
elif "redux_down.weight" in keys:
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
model.load_state_dict(model_data)
|
||||
@@ -430,6 +470,7 @@ class CLIPType(Enum):
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
MOCHI = 7
|
||||
LTXV = 8
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
@@ -508,6 +549,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
|
@@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.genmo
|
||||
import comfy.text_encoders.lt
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -197,6 +198,8 @@ class SDXL(supported_models_base.BASE):
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
|
||||
self.sampling_settings["zsnr"] = True
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
@@ -700,7 +703,34 @@ class GenmoMochi(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
|
||||
|
||||
class LTXV(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ltxv",
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
|
||||
sampling_settings = {
|
||||
"shift": 2.37,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.LTXV
|
||||
|
||||
memory_usage_factor = 2.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXV(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
18
comfy/text_encoders/lt.py
Normal file
18
comfy/text_encoders/lt.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.text_encoders.genmo
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
|
||||
|
||||
|
||||
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def ltxv_te(*args, **kwargs):
|
||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
@@ -848,3 +848,24 @@ class ProgressBar:
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
def reshape_mask(input_mask, output_shape):
|
||||
dims = len(output_shape) - 2
|
||||
|
||||
if dims == 1:
|
||||
scale_mode = "linear"
|
||||
|
||||
if dims == 2:
|
||||
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||
scale_mode = "bilinear"
|
||||
|
||||
if dims == 3:
|
||||
if len(input_mask.shape) < 5:
|
||||
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||
scale_mode = "trilinear"
|
||||
|
||||
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
|
||||
if mask.shape[1] < output_shape[1]:
|
||||
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
|
||||
return mask
|
||||
|
181
comfy_extras/nodes_lt.py
Normal file
181
comfy_extras/nodes_lt.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
|
||||
class EmptyLTXVLatentVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/video/ltxv"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent}, )
|
||||
|
||||
|
||||
class LTXVImgToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
FUNCTION = "generate"
|
||||
|
||||
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
|
||||
|
||||
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||
latent[:, :, :t.shape[2]] = t
|
||||
return (positive, negative, {"samples": latent}, )
|
||||
|
||||
|
||||
class LTXVConditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def append(self, positive, negative, frame_rate):
|
||||
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
|
||||
return (positive, negative)
|
||||
|
||||
|
||||
class ModelSamplingLTXV:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
},
|
||||
"optional": {"latent": ("LATENT",), }
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, max_shift, base_shift, latent=None):
|
||||
m = model.clone()
|
||||
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
shift = (tokens) * mm + b
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingFlux
|
||||
sampling_type = comfy.model_sampling.CONST
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift=shift)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
|
||||
class LTXVScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"stretch": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
|
||||
}),
|
||||
"terminal": (
|
||||
"FLOAT",
|
||||
{
|
||||
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
|
||||
"tooltip": "The terminal value of the sigmas after stretching."
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {"latent": ("LATENT",), }
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1)
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
sigma_shift = (tokens) * mm + b
|
||||
|
||||
power = 1
|
||||
sigmas = torch.where(
|
||||
sigmas != 0,
|
||||
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
||||
0,
|
||||
)
|
||||
|
||||
# Stretch sigmas so that its final value matches the given terminal value.
|
||||
if stretch:
|
||||
non_zero_mask = sigmas != 0
|
||||
non_zero_sigmas = sigmas[non_zero_mask]
|
||||
one_minus_z = 1.0 - non_zero_sigmas
|
||||
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||
sigmas[non_zero_mask] = stretched
|
||||
|
||||
return (sigmas,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
|
||||
"LTXVImgToVideo": LTXVImgToVideo,
|
||||
"ModelSamplingLTXV": ModelSamplingLTXV,
|
||||
"LTXVConditioning": LTXVConditioning,
|
||||
"LTXVScheduler": LTXVScheduler,
|
||||
}
|
@@ -3,9 +3,6 @@ import torch
|
||||
import comfy.model_management
|
||||
|
||||
class EmptyMochiLatentVideo:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
@@ -15,10 +12,10 @@ class EmptyMochiLatentVideo:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/mochi"
|
||||
CATEGORY = "latent/video"
|
||||
|
||||
def generate(self, width, height, length, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
|
||||
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples":latent}, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
@@ -26,8 +26,8 @@ class X0(comfy.model_sampling.EPS):
|
||||
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||
original_timesteps = 50
|
||||
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__(model_config)
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
super().__init__(model_config, zsnr=zsnr)
|
||||
|
||||
self.skip_steps = self.num_timesteps // self.original_timesteps
|
||||
|
||||
@@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
if zsnr:
|
||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
|
||||
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
@@ -75,6 +75,34 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
|
||||
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["init_x_linear."] = argument
|
||||
arg_dict["positional_encoding"] = argument
|
||||
arg_dict["cond_seq_linear."] = argument
|
||||
arg_dict["register_tokens"] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
|
||||
for i in range(4):
|
||||
arg_dict["double_layers.{}.".format(i)] = argument
|
||||
|
||||
for i in range(32):
|
||||
arg_dict["single_layers.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["modF."] = argument
|
||||
arg_dict["final_linear."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@@ -124,11 +152,35 @@ class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_frequencies."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["t5_y_embedder."] = argument
|
||||
arg_dict["t5_yproj."] = argument
|
||||
|
||||
for i in range(48):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
"ModelMergeSDXL": ModelMergeSDXL,
|
||||
"ModelMergeSD3_2B": ModelMergeSD3_2B,
|
||||
"ModelMergeAuraflow": ModelMergeAuraflow,
|
||||
"ModelMergeFlux1": ModelMergeFlux1,
|
||||
"ModelMergeSD35_Large": ModelMergeSD35_Large,
|
||||
"ModelMergeMochiPreview": ModelMergeMochiPreview,
|
||||
}
|
||||
|
@@ -57,12 +57,24 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
|
||||
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||
|
||||
total = mask.shape[-1]
|
||||
x = round(math.sqrt((lh / lw) * total))
|
||||
xx = None
|
||||
for i in range(0, math.floor(math.sqrt(total) / 2)):
|
||||
for j in [(x + i), max(1, x - i)]:
|
||||
if total % j == 0:
|
||||
xx = j
|
||||
break
|
||||
if xx is not None:
|
||||
break
|
||||
|
||||
x = xx
|
||||
y = total // x
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, *mid_shape)
|
||||
mask.reshape(b, x, y)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
|
@@ -3,7 +3,9 @@ import comfy.sd
|
||||
import comfy.model_management
|
||||
import nodes
|
||||
import torch
|
||||
import re
|
||||
import comfy_extras.nodes_slg
|
||||
|
||||
|
||||
class TripleCLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -14,6 +16,8 @@ class TripleCLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
@@ -21,6 +25,7 @@ class TripleCLIPLoader:
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (clip,)
|
||||
|
||||
|
||||
class EmptySD3LatentImage:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
@@ -39,6 +44,7 @@ class EmptySD3LatentImage:
|
||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
class CLIPTextEncodeSD3:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -95,7 +101,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
DEPRECATED = True
|
||||
|
||||
class SkipLayerGuidanceSD3:
|
||||
|
||||
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
|
||||
'''
|
||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||
@@ -110,47 +117,12 @@ class SkipLayerGuidanceSD3:
|
||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "skip_guidance"
|
||||
FUNCTION = "skip_guidance_sd3"
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
|
||||
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
|
||||
if layers == "" or layers == None:
|
||||
return (model, )
|
||||
# check if layer is comma separated integers
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def post_cfg_function(args):
|
||||
model = args["model"]
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
x = args["input"]
|
||||
model_options = args["model_options"].copy()
|
||||
|
||||
for layer in layers:
|
||||
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||
model_sampling.percent_to_sigma(start_percent)
|
||||
|
||||
sigma_ = sigma[0].item()
|
||||
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||
return cfg_result
|
||||
|
||||
layers = re.findall(r'\d+', layers)
|
||||
layers = [int(i) for i in layers]
|
||||
m = model.clone()
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return (m, )
|
||||
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
|
||||
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
78
comfy_extras/nodes_slg.py
Normal file
78
comfy_extras/nodes_slg.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import re
|
||||
|
||||
|
||||
class SkipLayerGuidanceDiT:
|
||||
'''
|
||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
||||
'''
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "skip_guidance"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
|
||||
# check if layer is comma separated integers
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
double_layers = re.findall(r'\d+', double_layers)
|
||||
double_layers = [int(i) for i in double_layers]
|
||||
|
||||
single_layers = re.findall(r'\d+', single_layers)
|
||||
single_layers = [int(i) for i in single_layers]
|
||||
|
||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||
return (model, )
|
||||
|
||||
def post_cfg_function(args):
|
||||
model = args["model"]
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
x = args["input"]
|
||||
model_options = args["model_options"].copy()
|
||||
|
||||
for layer in double_layers:
|
||||
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
|
||||
|
||||
for layer in single_layers:
|
||||
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)
|
||||
|
||||
model_sampling.percent_to_sigma(start_percent)
|
||||
|
||||
sigma_ = sigma[0].item()
|
||||
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
|
||||
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||
cfg_result = cfg_result + (cond_pred - slg) * scale
|
||||
return cfg_result
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return (m, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
||||
}
|
Binary file not shown.
Before Width: | Height: | Size: 116 KiB |
@@ -47,7 +47,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
|
||||
if self.latent_rgb_factors_bias is not None:
|
||||
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
|
||||
|
||||
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||
if x0.ndim == 5:
|
||||
x0 = x0[0, :, 0]
|
||||
else:
|
||||
x0 = x0[0]
|
||||
|
||||
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||
|
||||
return preview_to_image(latent_image)
|
||||
|
1
main.py
1
main.py
@@ -71,6 +71,7 @@ if os.name == "nt":
|
||||
if __name__ == "__main__":
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||
|
||||
if args.deterministic:
|
||||
|
29
nodes.py
29
nodes.py
@@ -290,15 +290,22 @@ class VAEDecodeTiled:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
|
||||
"tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
|
||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def decode(self, vae, samples, tile_size):
|
||||
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
|
||||
def decode(self, vae, samples, tile_size, overlap=64):
|
||||
if tile_size < overlap * 4:
|
||||
overlap = tile_size // 4
|
||||
compression = vae.spacial_compression_decode()
|
||||
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
|
||||
if len(images.shape) == 5: #Combine batches
|
||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||
return (images, )
|
||||
|
||||
class VAEEncode:
|
||||
@classmethod
|
||||
@@ -376,6 +383,7 @@ class InpaintModelConditioning:
|
||||
"vae": ("VAE", ),
|
||||
"pixels": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
|
||||
@@ -384,7 +392,7 @@ class InpaintModelConditioning:
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, positive, negative, pixels, vae, mask):
|
||||
def encode(self, positive, negative, pixels, vae, mask, noise_mask):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
@@ -408,7 +416,8 @@ class InpaintModelConditioning:
|
||||
out_latent = {}
|
||||
|
||||
out_latent["samples"] = orig_latent
|
||||
out_latent["noise_mask"] = mask
|
||||
if noise_mask:
|
||||
out_latent["noise_mask"] = mask
|
||||
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
@@ -889,13 +898,15 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||
if type == "stable_cascade":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
@@ -905,6 +916,8 @@ class CLIPLoader:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
||||
elif type == "mochi":
|
||||
clip_type = comfy.sd.CLIPType.MOCHI
|
||||
elif type == "ltxv":
|
||||
clip_type = comfy.sd.CLIPType.LTXV
|
||||
else:
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
|
||||
@@ -924,6 +937,8 @@ class DualCLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||
@@ -2123,6 +2138,8 @@ def init_builtin_extra_nodes():
|
||||
"nodes_lora_extract.py",
|
||||
"nodes_torch_compile.py",
|
||||
"nodes_mochi.py",
|
||||
"nodes_slg.py",
|
||||
"nodes_lt.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
@@ -152,7 +152,7 @@ class PromptServer():
|
||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.internal_routes = InternalRoutes()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = None
|
||||
self.loop = loop
|
||||
|
@@ -14,7 +14,7 @@ def user_manager(tmp_path):
|
||||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
)
|
||||
) if file else tmp_path
|
||||
return um
|
||||
|
||||
|
||||
@@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == [
|
||||
["subdir/file1.txt", "subdir", "file1.txt"]
|
||||
]
|
||||
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
|
||||
|
||||
|
||||
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||
@@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
||||
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||
assert result[0]["path"] == "subdir/file1.txt"
|
||||
|
||||
|
||||
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
content = b"test content"
|
||||
resp = await client.post("/userdata/test.txt", data=content)
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"test.txt"'
|
||||
|
||||
# Verify file was created with correct content
|
||||
with open(tmp_path / "test.txt", "rb") as f:
|
||||
assert f.read() == content
|
||||
|
||||
|
||||
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "test.txt", "w") as f:
|
||||
f.write("initial content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
new_content = b"updated content"
|
||||
resp = await client.post("/userdata/test.txt", data=new_content)
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"test.txt"'
|
||||
|
||||
# Verify file was overwritten
|
||||
with open(tmp_path / "test.txt", "rb") as f:
|
||||
assert f.read() == new_content
|
||||
|
||||
|
||||
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "test.txt", "w") as f:
|
||||
f.write("initial content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
|
||||
|
||||
assert resp.status == 409
|
||||
|
||||
# Verify original content unchanged
|
||||
with open(tmp_path / "test.txt", "r") as f:
|
||||
assert f.read() == "initial content"
|
||||
|
||||
|
||||
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
content = b"test content"
|
||||
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
|
||||
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert result["path"] == "test.txt"
|
||||
assert result["size"] == len(content)
|
||||
assert "modified" in result
|
||||
|
||||
|
||||
async def test_move_userdata(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt")
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"dest.txt"'
|
||||
|
||||
# Verify file was moved
|
||||
assert not os.path.exists(tmp_path / "source.txt")
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "test content"
|
||||
|
||||
|
||||
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
|
||||
# Create source and destination files
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("source content")
|
||||
with open(tmp_path / "dest.txt", "w") as f:
|
||||
f.write("destination content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
|
||||
|
||||
assert resp.status == 409
|
||||
|
||||
# Verify files remain unchanged
|
||||
with open(tmp_path / "source.txt", "r") as f:
|
||||
assert f.read() == "source content"
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "destination content"
|
||||
|
||||
|
||||
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
|
||||
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert result["path"] == "dest.txt"
|
||||
assert result["size"] == len("test content")
|
||||
assert "modified" in result
|
||||
|
||||
# Verify file was moved
|
||||
assert not os.path.exists(tmp_path / "source.txt")
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "test content"
|
||||
|
@@ -8,7 +8,7 @@ from folder_paths import models_dir, user_directory, output_directory
|
||||
|
||||
@pytest.fixture
|
||||
def internal_routes():
|
||||
return InternalRoutes()
|
||||
return InternalRoutes(None)
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client_factory(aiohttp_client, internal_routes):
|
||||
@@ -102,7 +102,7 @@ async def test_file_service_initialization():
|
||||
# Create a mock instance
|
||||
mock_file_service_instance = MagicMock(spec=FileService)
|
||||
MockFileService.return_value = mock_file_service_instance
|
||||
internal_routes = InternalRoutes()
|
||||
internal_routes = InternalRoutes(None)
|
||||
|
||||
# Check if FileService was initialized with the correct parameters
|
||||
MockFileService.assert_called_once_with({
|
||||
@@ -112,4 +112,4 @@ async def test_file_service_initialization():
|
||||
})
|
||||
|
||||
# Verify that the file_service attribute of InternalRoutes is set
|
||||
assert internal_routes.file_service == mock_file_service_instance
|
||||
assert internal_routes.file_service == mock_file_service_instance
|
||||
|
8
web/assets/ExtensionPanel-BmKi_NKS.js → web/assets/ExtensionPanel-CfMfcLgI.js
generated
vendored
8
web/assets/ExtensionPanel-BmKi_NKS.js → web/assets/ExtensionPanel-CfMfcLgI.js
generated
vendored
@@ -1,8 +1,8 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, bQ as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bR as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a6 as toDisplayString, aw as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-BHayQCxv.js";
|
||||
import { s as script, a as script$2, b as script$3 } from "./index-CwRXxFdA.js";
|
||||
import "./index-C_wOqB0f.js";
|
||||
import { d as defineComponent, c6 as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bT as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a6 as toDisplayString, aw as createTextVNode, bR as script$3, j as createCommentVNode, D as script$4 } from "./index-B6dYHNhg.js";
|
||||
import { s as script, a as script$2 } from "./index-CjwCGacA.js";
|
||||
import "./index-MX9DEi8Q.js";
|
||||
const _hoisted_1 = { class: "extension-panel" };
|
||||
const _hoisted_2 = { class: "mt-4" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
@@ -100,4 +100,4 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=ExtensionPanel-BmKi_NKS.js.map
|
||||
//# sourceMappingURL=ExtensionPanel-CfMfcLgI.js.map
|
@@ -1 +1 @@
|
||||
{"version":3,"file":"ExtensionPanel-BmKi_NKS.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
{"version":3,"file":"ExtensionPanel-CfMfcLgI.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
627
web/assets/GraphView-C4blCugc.js → web/assets/GraphView-BCOd0Zle.js
generated
vendored
627
web/assets/GraphView-C4blCugc.js → web/assets/GraphView-BCOd0Zle.js
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-BCOd0Zle.js.map
generated
vendored
Normal file
1
web/assets/GraphView-BCOd0Zle.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/GraphView-C4blCugc.js.map
generated
vendored
1
web/assets/GraphView-C4blCugc.js.map
generated
vendored
File diff suppressed because one or more lines are too long
64
web/assets/GraphView-Cf7ubG48.css → web/assets/GraphView-CghYAxkP.css
generated
vendored
64
web/assets/GraphView-Cf7ubG48.css → web/assets/GraphView-CghYAxkP.css
generated
vendored
@@ -45,7 +45,7 @@
|
||||
--sidebar-icon-size: 1rem;
|
||||
}
|
||||
|
||||
.side-tool-bar-container[data-v-37fd2fa4] {
|
||||
.side-tool-bar-container[data-v-e0812a25] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
@@ -58,28 +58,32 @@
|
||||
background-color: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
}
|
||||
.side-tool-bar-end[data-v-37fd2fa4] {
|
||||
.side-tool-bar-end[data-v-e0812a25] {
|
||||
align-self: flex-end;
|
||||
margin-top: auto;
|
||||
}
|
||||
|
||||
[data-v-b49f20b1] .p-splitter-gutter {
|
||||
[data-v-7c3279c1] .p-splitter-gutter {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.side-bar-panel[data-v-b49f20b1] {
|
||||
[data-v-7c3279c1] .p-splitter-gutter:hover,[data-v-7c3279c1] .p-splitter-gutter[data-p-gutter-resizing='true'] {
|
||||
transition: background-color 0.2s ease 300ms;
|
||||
background-color: var(--p-primary-color);
|
||||
}
|
||||
.side-bar-panel[data-v-7c3279c1] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.bottom-panel[data-v-b49f20b1] {
|
||||
.bottom-panel[data-v-7c3279c1] {
|
||||
background-color: var(--bg-color);
|
||||
pointer-events: auto;
|
||||
}
|
||||
.splitter-overlay[data-v-b49f20b1] {
|
||||
.splitter-overlay[data-v-7c3279c1] {
|
||||
pointer-events: none;
|
||||
border-style: none;
|
||||
background-color: transparent;
|
||||
}
|
||||
.splitter-overlay-root[data-v-b49f20b1] {
|
||||
.splitter-overlay-root[data-v-7c3279c1] {
|
||||
position: absolute;
|
||||
top: 0px;
|
||||
left: 0px;
|
||||
@@ -146,7 +150,7 @@
|
||||
align-items: flex-start !important;
|
||||
}
|
||||
|
||||
.node-tooltip[data-v-79ec8c53] {
|
||||
.node-tooltip[data-v-c2e0098f] {
|
||||
background: var(--comfy-input-bg);
|
||||
border-radius: 5px;
|
||||
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||
@@ -162,22 +166,28 @@
|
||||
z-index: 99999;
|
||||
}
|
||||
|
||||
.p-buttongroup-vertical[data-v-444d3768] {
|
||||
.p-buttongroup-vertical[data-v-94481f39] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
border-radius: var(--p-button-border-radius);
|
||||
overflow: hidden;
|
||||
border: 1px solid var(--p-panel-border-color);
|
||||
}
|
||||
.p-buttongroup-vertical .p-button[data-v-444d3768] {
|
||||
.p-buttongroup-vertical .p-button[data-v-94481f39] {
|
||||
margin: 0;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
[data-v-84e785b8] .p-togglebutton::before {
|
||||
.comfy-menu-hamburger[data-v-2ddd26e8] {
|
||||
pointer-events: auto;
|
||||
position: fixed;
|
||||
z-index: 9999;
|
||||
}
|
||||
|
||||
[data-v-9eb975c3] .p-togglebutton::before {
|
||||
display: none
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton {
|
||||
[data-v-9eb975c3] .p-togglebutton {
|
||||
position: relative;
|
||||
flex-shrink: 0;
|
||||
border-radius: 0px;
|
||||
@@ -185,14 +195,14 @@
|
||||
padding-left: 0.5rem;
|
||||
padding-right: 0.5rem
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
|
||||
[data-v-9eb975c3] .p-togglebutton.p-togglebutton-checked {
|
||||
border-bottom-width: 2px;
|
||||
border-bottom-color: var(--p-button-text-primary-color)
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
|
||||
[data-v-9eb975c3] .p-togglebutton-checked .close-button,[data-v-9eb975c3] .p-togglebutton:hover .close-button {
|
||||
visibility: visible
|
||||
}
|
||||
.status-indicator[data-v-84e785b8] {
|
||||
.status-indicator[data-v-9eb975c3] {
|
||||
position: absolute;
|
||||
font-weight: 700;
|
||||
font-size: 1.5rem;
|
||||
@@ -200,10 +210,10 @@
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%)
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
|
||||
[data-v-9eb975c3] .p-togglebutton:hover .status-indicator {
|
||||
display: none
|
||||
}
|
||||
[data-v-84e785b8] .p-togglebutton .close-button {
|
||||
[data-v-9eb975c3] .p-togglebutton .close-button {
|
||||
visibility: hidden
|
||||
}
|
||||
|
||||
@@ -226,35 +236,35 @@
|
||||
border-bottom-left-radius: 0;
|
||||
}
|
||||
|
||||
.comfyui-queue-button[data-v-2b80bf74] .p-splitbutton-dropdown {
|
||||
.comfyui-queue-button[data-v-95bc9be0] .p-splitbutton-dropdown {
|
||||
border-top-right-radius: 0;
|
||||
border-bottom-right-radius: 0;
|
||||
}
|
||||
|
||||
.actionbar[data-v-2e54db00] {
|
||||
.actionbar[data-v-eb6e9acf] {
|
||||
pointer-events: all;
|
||||
position: fixed;
|
||||
z-index: 1000;
|
||||
}
|
||||
.actionbar.is-docked[data-v-2e54db00] {
|
||||
.actionbar.is-docked[data-v-eb6e9acf] {
|
||||
position: static;
|
||||
border-style: none;
|
||||
background-color: transparent;
|
||||
padding: 0px;
|
||||
}
|
||||
.actionbar.is-dragging[data-v-2e54db00] {
|
||||
.actionbar.is-dragging[data-v-eb6e9acf] {
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
[data-v-2e54db00] .p-panel-content {
|
||||
[data-v-eb6e9acf] .p-panel-content {
|
||||
padding: 0.25rem;
|
||||
}
|
||||
[data-v-2e54db00] .p-panel-header {
|
||||
[data-v-eb6e9acf] .p-panel-header {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.comfyui-menu[data-v-ad2c662b] {
|
||||
.comfyui-menu[data-v-d84a704d] {
|
||||
width: 100vw;
|
||||
background: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
@@ -266,13 +276,13 @@
|
||||
grid-column: 1/-1;
|
||||
max-height: 90vh;
|
||||
}
|
||||
.comfyui-menu.dropzone[data-v-ad2c662b] {
|
||||
.comfyui-menu.dropzone[data-v-d84a704d] {
|
||||
background: var(--p-highlight-background);
|
||||
}
|
||||
.comfyui-menu.dropzone-active[data-v-ad2c662b] {
|
||||
.comfyui-menu.dropzone-active[data-v-d84a704d] {
|
||||
background: var(--p-highlight-background-focus);
|
||||
}
|
||||
.comfyui-logo[data-v-ad2c662b] {
|
||||
.comfyui-logo[data-v-d84a704d] {
|
||||
font-size: 1.2em;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
4
web/assets/InstallView-CN3CA9Fk.css
generated
vendored
Normal file
4
web/assets/InstallView-CN3CA9Fk.css
generated
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
[data-v-53e62b05] .p-steppanel {
|
||||
background-color: transparent
|
||||
}
|
1048
web/assets/InstallView-D9ueAxrz.js
generated
vendored
Normal file
1048
web/assets/InstallView-D9ueAxrz.js
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/InstallView-D9ueAxrz.js.map
generated
vendored
Normal file
1
web/assets/InstallView-D9ueAxrz.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
8
web/assets/KeybindingPanel-BNYKhW1k.css
generated
vendored
8
web/assets/KeybindingPanel-BNYKhW1k.css
generated
vendored
@@ -1,8 +0,0 @@
|
||||
|
||||
[data-v-e5724e4d] .p-datatable-tbody > tr > td {
|
||||
padding: 1px;
|
||||
min-height: 2rem;
|
||||
}
|
||||
[data-v-e5724e4d] .p-datatable-row-selected .actions,[data-v-e5724e4d] .p-datatable-selectable-row:hover .actions {
|
||||
visibility: visible;
|
||||
}
|
8
web/assets/KeybindingPanel-CB_wEOHl.css
generated
vendored
Normal file
8
web/assets/KeybindingPanel-CB_wEOHl.css
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
[data-v-2d8b3a76] .p-datatable-tbody > tr > td {
|
||||
padding: 0.25rem;
|
||||
min-height: 2rem
|
||||
}
|
||||
[data-v-2d8b3a76] .p-datatable-row-selected .actions,[data-v-2d8b3a76] .p-datatable-selectable-row:hover .actions {
|
||||
visibility: visible
|
||||
}
|
30
web/assets/KeybindingPanel-Dm_3sBT5.js → web/assets/KeybindingPanel-DcEfyPZZ.js
generated
vendored
30
web/assets/KeybindingPanel-Dm_3sBT5.js → web/assets/KeybindingPanel-DcEfyPZZ.js
generated
vendored
@@ -1,8 +1,8 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, bN as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, b9 as useToast, t as resolveDirective, bO as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, ao as script$4, be as withModifiers, aH as script$6, v as withDirectives, P as pushScopeId, Q as popScopeId, bJ as KeyComboImpl, bP as KeybindingImpl, _ as _export_sfc } from "./index-BHayQCxv.js";
|
||||
import { s as script$1, a as script$3, b as script$5 } from "./index-CwRXxFdA.js";
|
||||
import "./index-C_wOqB0f.js";
|
||||
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, c3 as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, be as useToast, t as resolveDirective, c4 as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, ao as script$4, bi as withModifiers, bR as script$5, aH as script$6, v as withDirectives, P as pushScopeId, Q as popScopeId, b$ as KeyComboImpl, c5 as KeybindingImpl, _ as _export_sfc } from "./index-B6dYHNhg.js";
|
||||
import { s as script$1, a as script$3 } from "./index-CjwCGacA.js";
|
||||
import "./index-MX9DEi8Q.js";
|
||||
const _hoisted_1$1 = {
|
||||
key: 0,
|
||||
class: "px-2"
|
||||
@@ -35,10 +35,11 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({
|
||||
};
|
||||
}
|
||||
});
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-e5724e4d"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-2d8b3a76"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
const _hoisted_1 = { class: "keybinding-panel" };
|
||||
const _hoisted_2 = { class: "actions invisible" };
|
||||
const _hoisted_3 = { key: 1 };
|
||||
const _hoisted_2 = { class: "actions invisible flex flex-row" };
|
||||
const _hoisted_3 = ["title"];
|
||||
const _hoisted_4 = { key: 1 };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "KeybindingPanel",
|
||||
setup(__props) {
|
||||
@@ -177,7 +178,16 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
createVNode(unref(script$1), {
|
||||
field: "id",
|
||||
header: "Command ID",
|
||||
sortable: ""
|
||||
sortable: "",
|
||||
class: "max-w-64 2xl:max-w-full"
|
||||
}, {
|
||||
body: withCtx((slotProps) => [
|
||||
createBaseVNode("div", {
|
||||
class: "overflow-hidden text-ellipsis whitespace-nowrap",
|
||||
title: slotProps.data.id
|
||||
}, toDisplayString(slotProps.data.id), 9, _hoisted_3)
|
||||
]),
|
||||
_: 1
|
||||
}),
|
||||
createVNode(unref(script$1), {
|
||||
field: "keybinding",
|
||||
@@ -188,7 +198,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
key: 0,
|
||||
keyCombo: slotProps.data.keybinding.combo,
|
||||
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
|
||||
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
|
||||
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_4, "-"))
|
||||
]),
|
||||
_: 1
|
||||
})
|
||||
@@ -257,8 +267,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
};
|
||||
}
|
||||
});
|
||||
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-e5724e4d"]]);
|
||||
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-2d8b3a76"]]);
|
||||
export {
|
||||
KeybindingPanel as default
|
||||
};
|
||||
//# sourceMappingURL=KeybindingPanel-Dm_3sBT5.js.map
|
||||
//# sourceMappingURL=KeybindingPanel-DcEfyPZZ.js.map
|
1
web/assets/KeybindingPanel-DcEfyPZZ.js.map
generated
vendored
Normal file
1
web/assets/KeybindingPanel-DcEfyPZZ.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/KeybindingPanel-Dm_3sBT5.js.map
generated
vendored
1
web/assets/KeybindingPanel-Dm_3sBT5.js.map
generated
vendored
File diff suppressed because one or more lines are too long
102
web/assets/ServerStartView-e57oVZ6V.js
generated
vendored
Normal file
102
web/assets/ServerStartView-e57oVZ6V.js
generated
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, r as ref, o as onMounted, w as watch, I as onBeforeUnmount, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, A as createBaseVNode, a6 as toDisplayString, z as unref, bK as script, bL as electronAPI } from "./index-B6dYHNhg.js";
|
||||
import { t, s } from "./index-B4gmhi99.js";
|
||||
const _hoisted_1$1 = { class: "p-terminal rounded-none h-full w-full" };
|
||||
const _hoisted_2$1 = { class: "px-4 whitespace-pre-wrap" };
|
||||
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
|
||||
__name: "LogTerminal",
|
||||
props: {
|
||||
fetchLogs: { type: Function },
|
||||
fetchInterval: {}
|
||||
},
|
||||
setup(__props) {
|
||||
const props = __props;
|
||||
const log = ref("");
|
||||
const scrollPanelRef = ref(null);
|
||||
const scrolledToBottom = ref(false);
|
||||
let intervalId = 0;
|
||||
onMounted(async () => {
|
||||
const element = scrollPanelRef.value?.$el;
|
||||
const scrollContainer = element?.querySelector(".p-scrollpanel-content");
|
||||
if (scrollContainer) {
|
||||
scrollContainer.addEventListener("scroll", () => {
|
||||
scrolledToBottom.value = scrollContainer.scrollTop + scrollContainer.clientHeight === scrollContainer.scrollHeight;
|
||||
});
|
||||
}
|
||||
const scrollToBottom = /* @__PURE__ */ __name(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollContainer.scrollHeight;
|
||||
}
|
||||
}, "scrollToBottom");
|
||||
watch(log, () => {
|
||||
if (scrolledToBottom.value) {
|
||||
scrollToBottom();
|
||||
}
|
||||
});
|
||||
const fetchLogs = /* @__PURE__ */ __name(async () => {
|
||||
log.value = await props.fetchLogs();
|
||||
}, "fetchLogs");
|
||||
await fetchLogs();
|
||||
scrollToBottom();
|
||||
intervalId = window.setInterval(fetchLogs, props.fetchInterval);
|
||||
});
|
||||
onBeforeUnmount(() => {
|
||||
window.clearInterval(intervalId);
|
||||
});
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1$1, [
|
||||
createVNode(unref(script), {
|
||||
class: "h-full w-full",
|
||||
ref_key: "scrollPanelRef",
|
||||
ref: scrollPanelRef
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createBaseVNode("pre", _hoisted_2$1, toDisplayString(log.value), 1)
|
||||
]),
|
||||
_: 1
|
||||
}, 512)
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "text-2xl font-bold" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "ServerStartView",
|
||||
setup(__props) {
|
||||
const electron = electronAPI();
|
||||
const status = ref(t.INITIAL_STATE);
|
||||
const logs = ref([]);
|
||||
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
|
||||
status.value = newStatus;
|
||||
logs.value = [];
|
||||
}, "updateProgress");
|
||||
const addLogMessage = /* @__PURE__ */ __name((message) => {
|
||||
logs.value = [...logs.value, message];
|
||||
}, "addLogMessage");
|
||||
const fetchLogs = /* @__PURE__ */ __name(async () => {
|
||||
return logs.value.join("\n");
|
||||
}, "fetchLogs");
|
||||
onMounted(() => {
|
||||
electron.sendReady();
|
||||
electron.onProgressUpdate(updateProgress);
|
||||
electron.onLogMessage((message) => {
|
||||
addLogMessage(message);
|
||||
});
|
||||
});
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createBaseVNode("h2", _hoisted_2, toDisplayString(unref(s)[status.value]), 1),
|
||||
createVNode(_sfc_main$1, {
|
||||
"fetch-logs": fetchLogs,
|
||||
"fetch-interval": 500
|
||||
})
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=ServerStartView-e57oVZ6V.js.map
|
1
web/assets/ServerStartView-e57oVZ6V.js.map
generated
vendored
Normal file
1
web/assets/ServerStartView-e57oVZ6V.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"ServerStartView-e57oVZ6V.js","sources":["../../src/components/common/LogTerminal.vue","../../src/views/ServerStartView.vue"],"sourcesContent":["<!-- A simple read-only terminal component that displays logs. -->\n<template>\n <div class=\"p-terminal rounded-none h-full w-full\">\n <ScrollPanel class=\"h-full w-full\" ref=\"scrollPanelRef\">\n <pre class=\"px-4 whitespace-pre-wrap\">{{ log }}</pre>\n </ScrollPanel>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport ScrollPanel from 'primevue/scrollpanel'\nimport { onBeforeUnmount, onMounted, ref, watch } from 'vue'\n\nconst props = defineProps<{\n fetchLogs: () => Promise<string>\n fetchInterval: number\n}>()\n\nconst log = ref<string>('')\nconst scrollPanelRef = ref<InstanceType<typeof ScrollPanel> | null>(null)\n/**\n * Whether the user has scrolled to the bottom of the terminal.\n * This is used to prevent the terminal from scrolling to the bottom\n * when new logs are fetched.\n */\nconst scrolledToBottom = ref(false)\n\nlet intervalId: number = 0\n\nonMounted(async () => {\n const element = scrollPanelRef.value?.$el\n const scrollContainer = element?.querySelector('.p-scrollpanel-content')\n\n if (scrollContainer) {\n scrollContainer.addEventListener('scroll', () => {\n scrolledToBottom.value =\n scrollContainer.scrollTop + scrollContainer.clientHeight ===\n scrollContainer.scrollHeight\n })\n }\n\n const scrollToBottom = () => {\n if (scrollContainer) {\n scrollContainer.scrollTop = scrollContainer.scrollHeight\n }\n }\n\n watch(log, () => {\n if (scrolledToBottom.value) {\n scrollToBottom()\n }\n })\n\n const fetchLogs = async () => {\n log.value = await props.fetchLogs()\n }\n\n await fetchLogs()\n scrollToBottom()\n intervalId = window.setInterval(fetchLogs, props.fetchInterval)\n})\n\nonBeforeUnmount(() => {\n window.clearInterval(intervalId)\n})\n</script>\n","<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <h2 class=\"text-2xl font-bold\">{{ ProgressMessages[status] }}</h2>\n <LogTerminal :fetch-logs=\"fetchLogs\" :fetch-interval=\"500\" />\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, onMounted } from 'vue'\nimport LogTerminal from '@/components/common/LogTerminal.vue'\nimport {\n ProgressStatus,\n ProgressMessages\n} from '@comfyorg/comfyui-electron-types'\nimport { electronAPI } from '@/utils/envUtil'\n\nconst electron = electronAPI()\n\nconst status = ref<ProgressStatus>(ProgressStatus.INITIAL_STATE)\nconst logs = ref<string[]>([])\n\nconst updateProgress = ({ status: newStatus }: { status: ProgressStatus }) => {\n status.value = newStatus\n logs.value = [] // Clear logs when status changes\n}\n\nconst addLogMessage = (message: string) => {\n logs.value = [...logs.value, message]\n}\n\nconst fetchLogs = async () => {\n return logs.value.join('\\n')\n}\n\nonMounted(() => {\n electron.sendReady()\n electron.onProgressUpdate(updateProgress)\n electron.onLogMessage((message: string) => {\n addLogMessage(message)\n })\n})\n</script>\n"],"names":["ProgressStatus"],"mappings":";;;;;;;;;;;;;AAaA,UAAM,QAAQ;AAKR,UAAA,MAAM,IAAY,EAAE;AACpB,UAAA,iBAAiB,IAA6C,IAAI;AAMlE,UAAA,mBAAmB,IAAI,KAAK;AAElC,QAAI,aAAqB;AAEzB,cAAU,YAAY;AACd,YAAA,UAAU,eAAe,OAAO;AAChC,YAAA,kBAAkB,SAAS,cAAc,wBAAwB;AAEvE,UAAI,iBAAiB;AACH,wBAAA,iBAAiB,UAAU,MAAM;AAC/C,2BAAiB,QACf,gBAAgB,YAAY,gBAAgB,iBAC5C,gBAAgB;AAAA,QAAA,CACnB;AAAA,MACH;AAEA,YAAM,iBAAiB,6BAAM;AAC3B,YAAI,iBAAiB;AACnB,0BAAgB,YAAY,gBAAgB;AAAA,QAC9C;AAAA,MAAA,GAHqB;AAMvB,YAAM,KAAK,MAAM;AACf,YAAI,iBAAiB,OAAO;AACX;QACjB;AAAA,MAAA,CACD;AAED,YAAM,YAAY,mCAAY;AACxB,YAAA,QAAQ,MAAM,MAAM,UAAU;AAAA,MAAA,GADlB;AAIlB,YAAM,UAAU;AACD;AACf,mBAAa,OAAO,YAAY,WAAW,MAAM,aAAa;AAAA,IAAA,CAC/D;AAED,oBAAgB,MAAM;AACpB,aAAO,cAAc,UAAU;AAAA,IAAA,CAChC;;;;;;;;;;;;;;;;;;;;;;AC9CD,UAAM,WAAW;AAEX,UAAA,SAAS,IAAoBA,EAAe,aAAa;AACzD,UAAA,OAAO,IAAc,CAAA,CAAE;AAE7B,UAAM,iBAAiB,wBAAC,EAAE,QAAQ,gBAA4C;AAC5E,aAAO,QAAQ;AACf,WAAK,QAAQ;IAAC,GAFO;AAKjB,UAAA,gBAAgB,wBAAC,YAAoB;AACzC,WAAK,QAAQ,CAAC,GAAG,KAAK,OAAO,OAAO;AAAA,IAAA,GADhB;AAItB,UAAM,YAAY,mCAAY;AACrB,aAAA,KAAK,MAAM,KAAK,IAAI;AAAA,IAAA,GADX;AAIlB,cAAU,MAAM;AACd,eAAS,UAAU;AACnB,eAAS,iBAAiB,cAAc;AAC/B,eAAA,aAAa,CAAC,YAAoB;AACzC,sBAAc,OAAO;AAAA,MAAA,CACtB;AAAA,IAAA,CACF;;;;;;;;;;;;"}
|
36
web/assets/WelcomeView-DQQgHnsr.css
generated
vendored
Normal file
36
web/assets/WelcomeView-DQQgHnsr.css
generated
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
|
||||
.animated-gradient-text[data-v-12b8b11b] {
|
||||
font-weight: 700;
|
||||
font-size: clamp(2rem, 8vw, 4rem);
|
||||
background: linear-gradient(to right, #12c2e9, #c471ed, #f64f59, #12c2e9);
|
||||
background-size: 300% auto;
|
||||
background-clip: text;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
animation: gradient-12b8b11b 8s linear infinite;
|
||||
}
|
||||
.text-glow[data-v-12b8b11b] {
|
||||
filter: drop-shadow(0 0 8px rgba(255, 255, 255, 0.3));
|
||||
}
|
||||
@keyframes gradient-12b8b11b {
|
||||
0% {
|
||||
background-position: 0% center;
|
||||
}
|
||||
100% {
|
||||
background-position: 300% center;
|
||||
}
|
||||
}
|
||||
.fade-in-up[data-v-12b8b11b] {
|
||||
animation: fadeInUp-12b8b11b 1.5s ease-out;
|
||||
animation-fill-mode: both;
|
||||
}
|
||||
@keyframes fadeInUp-12b8b11b {
|
||||
0% {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
33
web/assets/WelcomeView-DT4bj-QV.js
generated
vendored
Normal file
33
web/assets/WelcomeView-DT4bj-QV.js
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { d as defineComponent, g as openBlock, h as createElementBlock, A as createBaseVNode, a6 as toDisplayString, i as createVNode, z as unref, D as script, P as pushScopeId, Q as popScopeId, _ as _export_sfc } from "./index-B6dYHNhg.js";
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-12b8b11b"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "flex flex-col items-center justify-center gap-8 p-8" };
|
||||
const _hoisted_3 = { class: "animated-gradient-text text-glow select-none" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "WelcomeView",
|
||||
setup(__props) {
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createBaseVNode("div", _hoisted_2, [
|
||||
createBaseVNode("h1", _hoisted_3, toDisplayString(_ctx.$t("welcome.title")), 1),
|
||||
createVNode(unref(script), {
|
||||
label: _ctx.$t("welcome.getStarted"),
|
||||
icon: "pi pi-arrow-right",
|
||||
iconPos: "right",
|
||||
size: "large",
|
||||
rounded: "",
|
||||
onClick: _cache[0] || (_cache[0] = ($event) => _ctx.$router.push("/install")),
|
||||
class: "p-4 text-lg fade-in-up"
|
||||
}, null, 8, ["label"])
|
||||
])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
const WelcomeView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-12b8b11b"]]);
|
||||
export {
|
||||
WelcomeView as default
|
||||
};
|
||||
//# sourceMappingURL=WelcomeView-DT4bj-QV.js.map
|
1
web/assets/WelcomeView-DT4bj-QV.js.map
generated
vendored
Normal file
1
web/assets/WelcomeView-DT4bj-QV.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"WelcomeView-DT4bj-QV.js","sources":[],"sourcesContent":[],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
230
web/assets/index-BReiUkk9.js → web/assets/index-B1vRdV2i.js
generated
vendored
230
web/assets/index-BReiUkk9.js → web/assets/index-B1vRdV2i.js
generated
vendored
@@ -1,7 +1,7 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { bF as ComfyDialog, bG as $el, bH as ComfyApp, c as app, k as LiteGraph, b0 as LGraphCanvas, bI as DraggableList, ba as useToastStore, aE as useNodeDefStore, bC as api, L as LGraphGroup, bJ as KeyComboImpl, M as useKeybindingStore, F as useCommandStore, e as LGraphNode, bK as ComfyWidgets, bL as applyTextReplacements } from "./index-BHayQCxv.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs-DdecKYqd.js";
|
||||
import { bV as ComfyDialog, bW as $el, bX as ComfyApp, c as app, k as LiteGraph, b2 as LGraphCanvas, bY as DraggableList, bf as useToastStore, bZ as serialise, aE as useNodeDefStore, b_ as deserialiseAndCreate, bH as api, L as LGraphGroup, b$ as KeyComboImpl, M as useKeybindingStore, F as useCommandStore, e as LGraphNode, c0 as ComfyWidgets, c1 as applyTextReplacements } from "./index-B6dYHNhg.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs-BJ21PG7d.js";
|
||||
class ClipspaceDialog extends ComfyDialog {
|
||||
static {
|
||||
__name(this, "ClipspaceDialog");
|
||||
@@ -160,7 +160,7 @@ app.registerExtension({
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.clipspace = window.comfyAPI.clipspace || {};
|
||||
window.comfyAPI.clipspace.ClipspaceDialog = ClipspaceDialog;
|
||||
const ext$2 = {
|
||||
const ext$1 = {
|
||||
name: "Comfy.ContextMenuFilter",
|
||||
init() {
|
||||
const ctxMenu = LiteGraph.ContextMenu;
|
||||
@@ -178,7 +178,7 @@ const ext$2 = {
|
||||
let itemCount = displayedItems.length;
|
||||
requestAnimationFrame(() => {
|
||||
const currentNode = LGraphCanvas.active_canvas.current_node;
|
||||
const clickedComboValue = currentNode.widgets?.filter(
|
||||
const clickedComboValue = currentNode?.widgets?.filter(
|
||||
(w) => w.type === "combo" && w.options.values?.length === values.length
|
||||
).find(
|
||||
(w) => w.options.values?.every((v, i) => v === values[i])
|
||||
@@ -284,7 +284,7 @@ const ext$2 = {
|
||||
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
|
||||
}
|
||||
};
|
||||
app.registerExtension(ext$2);
|
||||
app.registerExtension(ext$1);
|
||||
function stripComments(str) {
|
||||
return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g, "");
|
||||
}
|
||||
@@ -966,17 +966,13 @@ class GroupNodeBuilder {
|
||||
}
|
||||
}
|
||||
}, "storeExternalLinks");
|
||||
const backup = localStorage.getItem("litegrapheditor_clipboard");
|
||||
try {
|
||||
app.canvas.copyToClipboard(this.nodes);
|
||||
const config = JSON.parse(
|
||||
localStorage.getItem("litegrapheditor_clipboard")
|
||||
);
|
||||
const serialised = serialise(this.nodes, app.canvas.graph);
|
||||
const config = JSON.parse(serialised);
|
||||
storeLinkTypes(config);
|
||||
storeExternalLinks(config);
|
||||
return config;
|
||||
} finally {
|
||||
localStorage.setItem("litegrapheditor_clipboard", backup);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1517,7 +1513,6 @@ class GroupNodeHandler {
|
||||
};
|
||||
this.node.convertToNodes = () => {
|
||||
const addInnerNodes = /* @__PURE__ */ __name(() => {
|
||||
const backup = localStorage.getItem("litegrapheditor_clipboard");
|
||||
const c = { ...this.groupData.nodeData };
|
||||
c.nodes = [...c.nodes];
|
||||
const innerNodes = this.node.getInnerNodes();
|
||||
@@ -1531,9 +1526,7 @@ class GroupNodeHandler {
|
||||
}
|
||||
c.nodes[i] = { ...c.nodes[i], id: id2 };
|
||||
}
|
||||
localStorage.setItem("litegrapheditor_clipboard", JSON.stringify(c));
|
||||
app.canvas.pasteFromClipboard();
|
||||
localStorage.setItem("litegrapheditor_clipboard", backup);
|
||||
deserialiseAndCreate(JSON.stringify(c), app.canvas);
|
||||
const [x, y] = this.node.pos;
|
||||
let top;
|
||||
let left;
|
||||
@@ -1580,10 +1573,8 @@ class GroupNodeHandler {
|
||||
}
|
||||
}
|
||||
for (const newNode of newNodes2) {
|
||||
newNode.pos = [
|
||||
newNode.pos[0] - (left - x),
|
||||
newNode.pos[1] - (top - y)
|
||||
];
|
||||
newNode.pos[0] -= left - x;
|
||||
newNode.pos[1] -= top - y;
|
||||
}
|
||||
return { newNodes: newNodes2, selectedIds: selectedIds2 };
|
||||
}, "addInnerNodes");
|
||||
@@ -1618,10 +1609,12 @@ class GroupNodeHandler {
|
||||
}
|
||||
}
|
||||
}, "reconnectOutputs");
|
||||
app.canvas.emitBeforeChange();
|
||||
const { newNodes, selectedIds } = addInnerNodes();
|
||||
reconnectInputs(selectedIds);
|
||||
reconnectOutputs(selectedIds);
|
||||
app.graph.remove(this.node);
|
||||
app.canvas.emitAfterChange();
|
||||
return newNodes;
|
||||
};
|
||||
const getExtraMenuOptions = this.node.getExtraMenuOptions;
|
||||
@@ -2030,10 +2023,10 @@ function manageGroupNodes() {
|
||||
new ManageGroupDialog(app).show();
|
||||
}
|
||||
__name(manageGroupNodes, "manageGroupNodes");
|
||||
const id$3 = "Comfy.GroupNode";
|
||||
const id$2 = "Comfy.GroupNode";
|
||||
let globalDefs;
|
||||
const ext$1 = {
|
||||
name: id$3,
|
||||
const ext = {
|
||||
name: id$2,
|
||||
commands: [
|
||||
{
|
||||
id: "Comfy.GroupNode.ConvertSelectedNodesToGroupNode",
|
||||
@@ -2103,56 +2096,18 @@ const ext$1 = {
|
||||
}
|
||||
}
|
||||
};
|
||||
app.registerExtension(ext$1);
|
||||
app.registerExtension(ext);
|
||||
window.comfyAPI = window.comfyAPI || {};
|
||||
window.comfyAPI.groupNode = window.comfyAPI.groupNode || {};
|
||||
window.comfyAPI.groupNode.GroupNodeConfig = GroupNodeConfig;
|
||||
window.comfyAPI.groupNode.GroupNodeHandler = GroupNodeHandler;
|
||||
function setNodeMode(node, mode) {
|
||||
node.mode = mode;
|
||||
node.graph.change();
|
||||
node.graph?.change();
|
||||
}
|
||||
__name(setNodeMode, "setNodeMode");
|
||||
function addNodesToGroup(group, nodes = []) {
|
||||
var x1, y1, x2, y2;
|
||||
var nx1, ny1, nx2, ny2;
|
||||
var node;
|
||||
x1 = y1 = x2 = y2 = -1;
|
||||
nx1 = ny1 = nx2 = ny2 = -1;
|
||||
for (var n of [group.nodes, nodes]) {
|
||||
for (var i in n) {
|
||||
node = n[i];
|
||||
nx1 = node.pos[0];
|
||||
ny1 = node.pos[1];
|
||||
nx2 = node.pos[0] + node.size[0];
|
||||
ny2 = node.pos[1] + node.size[1];
|
||||
if (node.type != "Reroute") {
|
||||
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
|
||||
}
|
||||
if (node.flags?.collapsed) {
|
||||
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;
|
||||
if (node?._collapsed_width) {
|
||||
nx2 = nx1 + Math.round(node._collapsed_width);
|
||||
}
|
||||
}
|
||||
if (x1 == -1 || nx1 < x1) {
|
||||
x1 = nx1;
|
||||
}
|
||||
if (y1 == -1 || ny1 < y1) {
|
||||
y1 = ny1;
|
||||
}
|
||||
if (x2 == -1 || nx2 > x2) {
|
||||
x2 = nx2;
|
||||
}
|
||||
if (y2 == -1 || ny2 > y2) {
|
||||
y2 = ny2;
|
||||
}
|
||||
}
|
||||
}
|
||||
var padding = 10;
|
||||
y1 = y1 - Math.round(group.font_size * 1.4);
|
||||
group.pos = [x1 - padding, y1 - padding];
|
||||
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
|
||||
function addNodesToGroup(group, items) {
|
||||
group.resizeTo([...group.children, ...items]);
|
||||
}
|
||||
__name(addNodesToGroup, "addNodesToGroup");
|
||||
app.registerExtension({
|
||||
@@ -2168,11 +2123,11 @@ app.registerExtension({
|
||||
if (!group) {
|
||||
options.push({
|
||||
content: "Add Group For Selected Nodes",
|
||||
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||
disabled: !this.selectedItems?.size,
|
||||
callback: /* @__PURE__ */ __name(() => {
|
||||
const group2 = new LGraphGroup();
|
||||
addNodesToGroup(group2, this.selected_nodes);
|
||||
app.canvas.graph.add(group2);
|
||||
addNodesToGroup(group2, this.selectedItems);
|
||||
this.graph.add(group2);
|
||||
this.graph.change();
|
||||
}, "callback")
|
||||
});
|
||||
@@ -2182,9 +2137,9 @@ app.registerExtension({
|
||||
const nodesInGroup = group.nodes;
|
||||
options.push({
|
||||
content: "Add Selected Nodes To Group",
|
||||
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||
disabled: !this.selectedItems?.size,
|
||||
callback: /* @__PURE__ */ __name(() => {
|
||||
addNodesToGroup(group, this.selected_nodes);
|
||||
addNodesToGroup(group, this.selectedItems);
|
||||
this.graph.change();
|
||||
}, "callback")
|
||||
});
|
||||
@@ -2203,7 +2158,8 @@ app.registerExtension({
|
||||
options.push({
|
||||
content: "Fit Group To Nodes",
|
||||
callback: /* @__PURE__ */ __name(() => {
|
||||
addNodesToGroup(group);
|
||||
group.recomputeInsideNodes();
|
||||
group.resizeTo(group.children);
|
||||
this.graph.change();
|
||||
}, "callback")
|
||||
});
|
||||
@@ -2329,9 +2285,9 @@ app.registerExtension({
|
||||
};
|
||||
}
|
||||
});
|
||||
const id$2 = "Comfy.InvertMenuScrolling";
|
||||
const id$1 = "Comfy.InvertMenuScrolling";
|
||||
app.registerExtension({
|
||||
name: id$2,
|
||||
name: id$1,
|
||||
init() {
|
||||
const ctxMenu = LiteGraph.ContextMenu;
|
||||
const replace = /* @__PURE__ */ __name(() => {
|
||||
@@ -2347,7 +2303,7 @@ app.registerExtension({
|
||||
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
|
||||
}, "replace");
|
||||
app.ui.settings.addSetting({
|
||||
id: id$2,
|
||||
id: id$1,
|
||||
category: ["Comfy", "Graph", "InvertMenuScrolling"],
|
||||
name: "Invert Context Menu Scrolling",
|
||||
type: "boolean",
|
||||
@@ -2379,8 +2335,8 @@ app.registerExtension({
|
||||
const commandStore = useCommandStore();
|
||||
const keybinding = keybindingStore.getKeybinding(keyCombo);
|
||||
if (keybinding && keybinding.targetSelector !== "#graph-canvas") {
|
||||
await commandStore.execute(keybinding.commandId);
|
||||
event.preventDefault();
|
||||
await commandStore.execute(keybinding.commandId);
|
||||
return;
|
||||
}
|
||||
if (event.ctrlKey || event.altKey || event.metaKey) {
|
||||
@@ -2403,35 +2359,6 @@ app.registerExtension({
|
||||
window.addEventListener("keydown", keybindListener);
|
||||
}
|
||||
});
|
||||
const id$1 = "Comfy.LinkRenderMode";
|
||||
const ext = {
|
||||
name: id$1,
|
||||
async setup(app2) {
|
||||
app2.ui.settings.addSetting({
|
||||
id: id$1,
|
||||
category: ["Comfy", "Graph", "LinkRenderMode"],
|
||||
name: "Link Render Mode",
|
||||
defaultValue: 2,
|
||||
type: "combo",
|
||||
options: [
|
||||
{ value: LiteGraph.STRAIGHT_LINK.toString(), text: "Straight" },
|
||||
{ value: LiteGraph.LINEAR_LINK.toString(), text: "Linear" },
|
||||
{ value: LiteGraph.SPLINE_LINK.toString(), text: "Spline" },
|
||||
{ value: LiteGraph.HIDDEN_LINK.toString(), text: "Hidden" }
|
||||
],
|
||||
onChange(value) {
|
||||
app2.canvas.links_render_mode = +value;
|
||||
app2.canvas.setDirty(
|
||||
/* fg */
|
||||
false,
|
||||
/* bg */
|
||||
true
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
app.registerExtension(ext);
|
||||
function dataURLToBlob(dataURL) {
|
||||
const parts = dataURL.split(";base64,");
|
||||
const contentType = parts[0].split(":")[1];
|
||||
@@ -3714,8 +3641,12 @@ app.registerExtension({
|
||||
clipboardAction(async () => {
|
||||
const data = JSON.parse(t.data);
|
||||
await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
|
||||
localStorage.setItem("litegrapheditor_clipboard", t.data);
|
||||
app.canvas.pasteFromClipboard();
|
||||
if (!data.reroutes) {
|
||||
deserialiseAndCreate(t.data, app.canvas);
|
||||
} else {
|
||||
localStorage.setItem("litegrapheditor_clipboard", t.data);
|
||||
app.canvas.pasteFromClipboard();
|
||||
}
|
||||
});
|
||||
}, "callback")
|
||||
};
|
||||
@@ -4049,9 +3980,10 @@ let touchCount = 0;
|
||||
app.registerExtension({
|
||||
name: "Comfy.SimpleTouchSupport",
|
||||
setup() {
|
||||
let zoomPos;
|
||||
let touchDist;
|
||||
let touchTime;
|
||||
let lastTouch;
|
||||
let lastScale;
|
||||
function getMultiTouchPos(e) {
|
||||
return Math.hypot(
|
||||
e.touches[0].clientX - e.touches[1].clientX,
|
||||
@@ -4059,63 +3991,90 @@ app.registerExtension({
|
||||
);
|
||||
}
|
||||
__name(getMultiTouchPos, "getMultiTouchPos");
|
||||
app.canvasEl.addEventListener(
|
||||
function getMultiTouchCenter(e) {
|
||||
return {
|
||||
clientX: (e.touches[0].clientX + e.touches[1].clientX) / 2,
|
||||
clientY: (e.touches[0].clientY + e.touches[1].clientY) / 2
|
||||
};
|
||||
}
|
||||
__name(getMultiTouchCenter, "getMultiTouchCenter");
|
||||
app.canvasEl.parentElement.addEventListener(
|
||||
"touchstart",
|
||||
(e) => {
|
||||
touchCount++;
|
||||
lastTouch = null;
|
||||
lastScale = null;
|
||||
if (e.touches?.length === 1) {
|
||||
touchTime = /* @__PURE__ */ new Date();
|
||||
lastTouch = e.touches[0];
|
||||
} else {
|
||||
touchTime = null;
|
||||
if (e.touches?.length === 2) {
|
||||
zoomPos = getMultiTouchPos(e);
|
||||
lastScale = app.canvas.ds.scale;
|
||||
lastTouch = getMultiTouchCenter(e);
|
||||
touchDist = getMultiTouchPos(e);
|
||||
app.canvas.pointer_is_down = false;
|
||||
}
|
||||
}
|
||||
},
|
||||
true
|
||||
);
|
||||
app.canvasEl.addEventListener("touchend", (e) => {
|
||||
touchZooming = false;
|
||||
touchCount = e.touches?.length ?? touchCount - 1;
|
||||
app.canvasEl.parentElement.addEventListener("touchend", (e) => {
|
||||
touchCount--;
|
||||
if (e.touches?.length !== 1) touchZooming = false;
|
||||
if (touchTime && !e.touches?.length) {
|
||||
if ((/* @__PURE__ */ new Date()).getTime() - touchTime > 600) {
|
||||
try {
|
||||
e.constructor = CustomEvent;
|
||||
} catch (error) {
|
||||
if (e.target === app.canvasEl) {
|
||||
app.canvasEl.dispatchEvent(
|
||||
new PointerEvent("pointerdown", {
|
||||
button: 2,
|
||||
clientX: e.changedTouches[0].clientX,
|
||||
clientY: e.changedTouches[0].clientY
|
||||
})
|
||||
);
|
||||
e.preventDefault();
|
||||
}
|
||||
e.clientX = lastTouch.clientX;
|
||||
e.clientY = lastTouch.clientY;
|
||||
app.canvas.pointer_is_down = true;
|
||||
app.canvas._mousedown_callback(e);
|
||||
}
|
||||
touchTime = null;
|
||||
}
|
||||
});
|
||||
app.canvasEl.addEventListener(
|
||||
app.canvasEl.parentElement.addEventListener(
|
||||
"touchmove",
|
||||
(e) => {
|
||||
touchTime = null;
|
||||
if (e.touches?.length === 2) {
|
||||
if (e.touches?.length === 2 && lastTouch && !e.ctrlKey && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
app.canvas.pointer_is_down = false;
|
||||
touchZooming = true;
|
||||
LiteGraph.closeAllContextMenus();
|
||||
LiteGraph.closeAllContextMenus(window);
|
||||
app.canvas.search_box?.close();
|
||||
const newZoomPos = getMultiTouchPos(e);
|
||||
const midX = (e.touches[0].clientX + e.touches[1].clientX) / 2;
|
||||
const midY = (e.touches[0].clientY + e.touches[1].clientY) / 2;
|
||||
let scale = app.canvas.ds.scale;
|
||||
const diff = zoomPos - newZoomPos;
|
||||
if (diff > 0.5) {
|
||||
scale *= 1 / 1.07;
|
||||
} else if (diff < -0.5) {
|
||||
scale *= 1.07;
|
||||
const newTouchDist = getMultiTouchPos(e);
|
||||
const center = getMultiTouchCenter(e);
|
||||
let scale = lastScale * newTouchDist / touchDist;
|
||||
const newX = (center.clientX - lastTouch.clientX) / scale;
|
||||
const newY = (center.clientY - lastTouch.clientY) / scale;
|
||||
if (scale < app.canvas.ds.min_scale) {
|
||||
scale = app.canvas.ds.min_scale;
|
||||
} else if (scale > app.canvas.ds.max_scale) {
|
||||
scale = app.canvas.ds.max_scale;
|
||||
}
|
||||
app.canvas.ds.changeScale(scale, [midX, midY]);
|
||||
const oldScale = app.canvas.ds.scale;
|
||||
app.canvas.ds.scale = scale;
|
||||
if (Math.abs(app.canvas.ds.scale - 1) < 0.01) {
|
||||
app.canvas.ds.scale = 1;
|
||||
}
|
||||
const newScale = app.canvas.ds.scale;
|
||||
const convertScaleToOffset = /* @__PURE__ */ __name((scale2) => [
|
||||
center.clientX / scale2 - app.canvas.ds.offset[0],
|
||||
center.clientY / scale2 - app.canvas.ds.offset[1]
|
||||
], "convertScaleToOffset");
|
||||
var oldCenter = convertScaleToOffset(oldScale);
|
||||
var newCenter = convertScaleToOffset(newScale);
|
||||
app.canvas.ds.offset[0] += newX + newCenter[0] - oldCenter[0];
|
||||
app.canvas.ds.offset[1] += newY + newCenter[1] - oldCenter[1];
|
||||
lastTouch.clientX = center.clientX;
|
||||
lastTouch.clientY = center.clientY;
|
||||
app.canvas.setDirty(true, true);
|
||||
zoomPos = newZoomPos;
|
||||
}
|
||||
},
|
||||
true
|
||||
@@ -4127,6 +4086,7 @@ LGraphCanvas.prototype.processMouseDown = function(e) {
|
||||
if (touchZooming || touchCount) {
|
||||
return;
|
||||
}
|
||||
app.canvas.pointer_is_down = false;
|
||||
return processMouseDown.apply(this, arguments);
|
||||
};
|
||||
const processMouseMove = LGraphCanvas.prototype.processMouseMove;
|
||||
@@ -4539,7 +4499,9 @@ app.registerExtension({
|
||||
/* name=*/
|
||||
"audioUI",
|
||||
audio,
|
||||
{ serialize: false }
|
||||
{
|
||||
serialize: false
|
||||
}
|
||||
);
|
||||
const isOutputNode = node.constructor.nodeData.output_node;
|
||||
if (isOutputNode) {
|
||||
@@ -4633,4 +4595,4 @@ app.registerExtension({
|
||||
};
|
||||
}
|
||||
});
|
||||
//# sourceMappingURL=index-BReiUkk9.js.map
|
||||
//# sourceMappingURL=index-B1vRdV2i.js.map
|
1
web/assets/index-B1vRdV2i.js.map
generated
vendored
Normal file
1
web/assets/index-B1vRdV2i.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
62
web/assets/index-B4gmhi99.js
generated
vendored
Normal file
62
web/assets/index-B4gmhi99.js
generated
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
const o = {
|
||||
LOADING_PROGRESS: "loading-progress",
|
||||
IS_PACKAGED: "is-packaged",
|
||||
RENDERER_READY: "renderer-ready",
|
||||
RESTART_APP: "restart-app",
|
||||
REINSTALL: "reinstall",
|
||||
LOG_MESSAGE: "log-message",
|
||||
OPEN_DIALOG: "open-dialog",
|
||||
DOWNLOAD_PROGRESS: "download-progress",
|
||||
START_DOWNLOAD: "start-download",
|
||||
PAUSE_DOWNLOAD: "pause-download",
|
||||
RESUME_DOWNLOAD: "resume-download",
|
||||
CANCEL_DOWNLOAD: "cancel-download",
|
||||
DELETE_MODEL: "delete-model",
|
||||
GET_ALL_DOWNLOADS: "get-all-downloads",
|
||||
GET_ELECTRON_VERSION: "get-electron-version",
|
||||
SEND_ERROR_TO_SENTRY: "send-error-to-sentry",
|
||||
GET_BASE_PATH: "get-base-path",
|
||||
GET_MODEL_CONFIG_PATH: "get-model-config-path",
|
||||
OPEN_PATH: "open-path",
|
||||
OPEN_LOGS_PATH: "open-logs-path",
|
||||
OPEN_DEV_TOOLS: "open-dev-tools",
|
||||
IS_FIRST_TIME_SETUP: "is-first-time-setup",
|
||||
GET_SYSTEM_PATHS: "get-system-paths",
|
||||
VALIDATE_INSTALL_PATH: "validate-install-path",
|
||||
VALIDATE_COMFYUI_SOURCE: "validate-comfyui-source",
|
||||
SHOW_DIRECTORY_PICKER: "show-directory-picker",
|
||||
INSTALL_COMFYUI: "install-comfyui"
|
||||
};
|
||||
var t = /* @__PURE__ */ ((e) => (e.INITIAL_STATE = "initial-state", e.PYTHON_SETUP = "python-setup", e.STARTING_SERVER = "starting-server", e.READY = "ready", e.ERROR = "error", e.ERROR_INSTALL_PATH = "error-install-path", e))(t || {});
|
||||
const s = {
|
||||
"initial-state": "Loading...",
|
||||
"python-setup": "Setting up Python Environment...",
|
||||
"starting-server": "Starting ComfyUI server...",
|
||||
ready: "Finishing...",
|
||||
error: "Was not able to start ComfyUI. Please check the logs for more details. You can open it from the Help menu. Please report issues to: https://forum.comfy.org",
|
||||
"error-install-path": "Installation path does not exist. Please reset the installation location."
|
||||
}, a = "electronAPI", n = "https://942cadba58d247c9cab96f45221aa813@o4507954455314432.ingest.us.sentry.io/4508007940685824", r = [
|
||||
{
|
||||
id: "user_files",
|
||||
label: "User Files",
|
||||
description: "Settings and user-created workflows"
|
||||
},
|
||||
{
|
||||
id: "models",
|
||||
label: "Models",
|
||||
description: "Reference model files from existing ComfyUI installations. (No copy)"
|
||||
}
|
||||
// TODO: Decide whether we want to auto-migrate custom nodes, and install their dependencies.
|
||||
// huchenlei: This is a very essential thing for migration experience.
|
||||
// {
|
||||
// id: 'custom_nodes',
|
||||
// label: 'Custom Nodes',
|
||||
// description: 'Reference custom node files from existing ComfyUI installations. (No copy)',
|
||||
// },
|
||||
];
|
||||
export {
|
||||
r,
|
||||
s,
|
||||
t
|
||||
};
|
||||
//# sourceMappingURL=index-B4gmhi99.js.map
|
1
web/assets/index-B4gmhi99.js.map
generated
vendored
Normal file
1
web/assets/index-B4gmhi99.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"index-B4gmhi99.js","sources":["../../node_modules/@comfyorg/comfyui-electron-types/index.mjs"],"sourcesContent":["const o = {\n LOADING_PROGRESS: \"loading-progress\",\n IS_PACKAGED: \"is-packaged\",\n RENDERER_READY: \"renderer-ready\",\n RESTART_APP: \"restart-app\",\n REINSTALL: \"reinstall\",\n LOG_MESSAGE: \"log-message\",\n OPEN_DIALOG: \"open-dialog\",\n DOWNLOAD_PROGRESS: \"download-progress\",\n START_DOWNLOAD: \"start-download\",\n PAUSE_DOWNLOAD: \"pause-download\",\n RESUME_DOWNLOAD: \"resume-download\",\n CANCEL_DOWNLOAD: \"cancel-download\",\n DELETE_MODEL: \"delete-model\",\n GET_ALL_DOWNLOADS: \"get-all-downloads\",\n GET_ELECTRON_VERSION: \"get-electron-version\",\n SEND_ERROR_TO_SENTRY: \"send-error-to-sentry\",\n GET_BASE_PATH: \"get-base-path\",\n GET_MODEL_CONFIG_PATH: \"get-model-config-path\",\n OPEN_PATH: \"open-path\",\n OPEN_LOGS_PATH: \"open-logs-path\",\n OPEN_DEV_TOOLS: \"open-dev-tools\",\n IS_FIRST_TIME_SETUP: \"is-first-time-setup\",\n GET_SYSTEM_PATHS: \"get-system-paths\",\n VALIDATE_INSTALL_PATH: \"validate-install-path\",\n VALIDATE_COMFYUI_SOURCE: \"validate-comfyui-source\",\n SHOW_DIRECTORY_PICKER: \"show-directory-picker\",\n INSTALL_COMFYUI: \"install-comfyui\"\n};\nvar t = /* @__PURE__ */ ((e) => (e.INITIAL_STATE = \"initial-state\", e.PYTHON_SETUP = \"python-setup\", e.STARTING_SERVER = \"starting-server\", e.READY = \"ready\", e.ERROR = \"error\", e.ERROR_INSTALL_PATH = \"error-install-path\", e))(t || {});\nconst s = {\n \"initial-state\": \"Loading...\",\n \"python-setup\": \"Setting up Python Environment...\",\n \"starting-server\": \"Starting ComfyUI server...\",\n ready: \"Finishing...\",\n error: \"Was not able to start ComfyUI. Please check the logs for more details. You can open it from the Help menu. Please report issues to: https://forum.comfy.org\",\n \"error-install-path\": \"Installation path does not exist. Please reset the installation location.\"\n}, a = \"electronAPI\", n = \"https://942cadba58d247c9cab96f45221aa813@o4507954455314432.ingest.us.sentry.io/4508007940685824\", r = [\n {\n id: \"user_files\",\n label: \"User Files\",\n description: \"Settings and user-created workflows\"\n },\n {\n id: \"models\",\n label: \"Models\",\n description: \"Reference model files from existing ComfyUI installations. (No copy)\"\n }\n // TODO: Decide whether we want to auto-migrate custom nodes, and install their dependencies.\n // huchenlei: This is a very essential thing for migration experience.\n // {\n // id: 'custom_nodes',\n // label: 'Custom Nodes',\n // description: 'Reference custom node files from existing ComfyUI installations. (No copy)',\n // },\n];\nexport {\n a as ELECTRON_BRIDGE_API,\n o as IPC_CHANNELS,\n r as MigrationItems,\n s as ProgressMessages,\n t as ProgressStatus,\n n as SENTRY_URL_ENDPOINT\n};\n"],"names":[],"mappings":"AAAA,MAAM,IAAI;AAAA,EACR,kBAAkB;AAAA,EAClB,aAAa;AAAA,EACb,gBAAgB;AAAA,EAChB,aAAa;AAAA,EACb,WAAW;AAAA,EACX,aAAa;AAAA,EACb,aAAa;AAAA,EACb,mBAAmB;AAAA,EACnB,gBAAgB;AAAA,EAChB,gBAAgB;AAAA,EAChB,iBAAiB;AAAA,EACjB,iBAAiB;AAAA,EACjB,cAAc;AAAA,EACd,mBAAmB;AAAA,EACnB,sBAAsB;AAAA,EACtB,sBAAsB;AAAA,EACtB,eAAe;AAAA,EACf,uBAAuB;AAAA,EACvB,WAAW;AAAA,EACX,gBAAgB;AAAA,EAChB,gBAAgB;AAAA,EAChB,qBAAqB;AAAA,EACrB,kBAAkB;AAAA,EAClB,uBAAuB;AAAA,EACvB,yBAAyB;AAAA,EACzB,uBAAuB;AAAA,EACvB,iBAAiB;AACnB;AACG,IAAC,IAAqB,kBAAC,OAAO,EAAE,gBAAgB,iBAAiB,EAAE,eAAe,gBAAgB,EAAE,kBAAkB,mBAAmB,EAAE,QAAQ,SAAS,EAAE,QAAQ,SAAS,EAAE,qBAAqB,sBAAsB,IAAI,KAAK,CAAA,CAAE;AACrO,MAAC,IAAI;AAAA,EACR,iBAAiB;AAAA,EACjB,gBAAgB;AAAA,EAChB,mBAAmB;AAAA,EACnB,OAAO;AAAA,EACP,OAAO;AAAA,EACP,sBAAsB;AACxB,GAAG,IAAI,eAAe,IAAI,mGAAmG,IAAI;AAAA,EAC/H;AAAA,IACE,IAAI;AAAA,IACJ,OAAO;AAAA,IACP,aAAa;AAAA,EACd;AAAA,EACD;AAAA,IACE,IAAI;AAAA,IACJ,OAAO;AAAA,IACP,aAAa;AAAA,EACd;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAAA;AAQH;","x_google_ignoreList":[0]}
|
72808
web/assets/index-BHayQCxv.js → web/assets/index-B6dYHNhg.js
generated
vendored
72808
web/assets/index-BHayQCxv.js → web/assets/index-B6dYHNhg.js
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/index-B6dYHNhg.js.map
generated
vendored
Normal file
1
web/assets/index-B6dYHNhg.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
685
web/assets/index-BitceZ14.css → web/assets/index-BCoLUtIt.css
generated
vendored
685
web/assets/index-BitceZ14.css → web/assets/index-BCoLUtIt.css
generated
vendored
File diff suppressed because it is too large
Load Diff
1
web/assets/index-BHayQCxv.js.map
generated
vendored
1
web/assets/index-BHayQCxv.js.map
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/index-BReiUkk9.js.map
generated
vendored
1
web/assets/index-BReiUkk9.js.map
generated
vendored
File diff suppressed because one or more lines are too long
102
web/assets/index-C_wOqB0f.js
generated
vendored
102
web/assets/index-C_wOqB0f.js
generated
vendored
@@ -1,102 +0,0 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { bS as script$4, A as createBaseVNode, g as openBlock, h as createElementBlock, m as mergeProps } from "./index-BHayQCxv.js";
|
||||
var script$3 = {
|
||||
name: "BarsIcon",
|
||||
"extends": script$4
|
||||
};
|
||||
var _hoisted_1$3 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
d: "M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2$3 = [_hoisted_1$3];
|
||||
function render$3(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2$3, 16);
|
||||
}
|
||||
__name(render$3, "render$3");
|
||||
script$3.render = render$3;
|
||||
var script$2 = {
|
||||
name: "PlusIcon",
|
||||
"extends": script$4
|
||||
};
|
||||
var _hoisted_1$2 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
d: "M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2$2 = [_hoisted_1$2];
|
||||
function render$2(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2$2, 16);
|
||||
}
|
||||
__name(render$2, "render$2");
|
||||
script$2.render = render$2;
|
||||
var script$1 = {
|
||||
name: "ExclamationTriangleIcon",
|
||||
"extends": script$4
|
||||
};
|
||||
var _hoisted_1$1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
d: "M13.4018 13.1893H0.598161C0.49329 13.189 0.390283 13.1615 0.299143 13.1097C0.208003 13.0578 0.131826 12.9832 0.0780112 12.8932C0.0268539 12.8015 0 12.6982 0 12.5931C0 12.4881 0.0268539 12.3848 0.0780112 12.293L6.47985 1.08982C6.53679 1.00399 6.61408 0.933574 6.70484 0.884867C6.7956 0.836159 6.897 0.810669 7 0.810669C7.103 0.810669 7.2044 0.836159 7.29516 0.884867C7.38592 0.933574 7.46321 1.00399 7.52015 1.08982L13.922 12.293C13.9731 12.3848 14 12.4881 14 12.5931C14 12.6982 13.9731 12.8015 13.922 12.8932C13.8682 12.9832 13.792 13.0578 13.7009 13.1097C13.6097 13.1615 13.5067 13.189 13.4018 13.1893ZM1.63046 11.989H12.3695L7 2.59425L1.63046 11.989Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
d: "M6.99996 8.78801C6.84143 8.78594 6.68997 8.72204 6.57787 8.60993C6.46576 8.49782 6.40186 8.34637 6.39979 8.18784V5.38703C6.39979 5.22786 6.46302 5.0752 6.57557 4.96265C6.68813 4.85009 6.84078 4.78686 6.99996 4.78686C7.15914 4.78686 7.31179 4.85009 7.42435 4.96265C7.5369 5.0752 7.60013 5.22786 7.60013 5.38703V8.18784C7.59806 8.34637 7.53416 8.49782 7.42205 8.60993C7.30995 8.72204 7.15849 8.78594 6.99996 8.78801Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_3 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
d: "M6.99996 11.1887C6.84143 11.1866 6.68997 11.1227 6.57787 11.0106C6.46576 10.8985 6.40186 10.7471 6.39979 10.5885V10.1884C6.39979 10.0292 6.46302 9.87658 6.57557 9.76403C6.68813 9.65147 6.84078 9.58824 6.99996 9.58824C7.15914 9.58824 7.31179 9.65147 7.42435 9.76403C7.5369 9.87658 7.60013 10.0292 7.60013 10.1884V10.5885C7.59806 10.7471 7.53416 10.8985 7.42205 11.0106C7.30995 11.1227 7.15849 11.1866 6.99996 11.1887Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_4 = [_hoisted_1$1, _hoisted_2$1, _hoisted_3];
|
||||
function render$1(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_4, 16);
|
||||
}
|
||||
__name(render$1, "render$1");
|
||||
script$1.render = render$1;
|
||||
var script = {
|
||||
name: "InfoCircleIcon",
|
||||
"extends": script$4
|
||||
};
|
||||
var _hoisted_1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
d: "M3.11101 12.8203C4.26215 13.5895 5.61553 14 7 14C8.85652 14 10.637 13.2625 11.9497 11.9497C13.2625 10.637 14 8.85652 14 7C14 5.61553 13.5895 4.26215 12.8203 3.11101C12.0511 1.95987 10.9579 1.06266 9.67879 0.532846C8.3997 0.00303296 6.99224 -0.13559 5.63437 0.134506C4.2765 0.404603 3.02922 1.07129 2.05026 2.05026C1.07129 3.02922 0.404603 4.2765 0.134506 5.63437C-0.13559 6.99224 0.00303296 8.3997 0.532846 9.67879C1.06266 10.9579 1.95987 12.0511 3.11101 12.8203ZM3.75918 2.14976C4.71846 1.50879 5.84628 1.16667 7 1.16667C8.5471 1.16667 10.0308 1.78125 11.1248 2.87521C12.2188 3.96918 12.8333 5.45291 12.8333 7C12.8333 8.15373 12.4912 9.28154 11.8502 10.2408C11.2093 11.2001 10.2982 11.9478 9.23232 12.3893C8.16642 12.8308 6.99353 12.9463 5.86198 12.7212C4.73042 12.4962 3.69102 11.9406 2.87521 11.1248C2.05941 10.309 1.50384 9.26958 1.27876 8.13803C1.05367 7.00647 1.16919 5.83358 1.61071 4.76768C2.05222 3.70178 2.79989 2.79074 3.75918 2.14976ZM7.00002 4.8611C6.84594 4.85908 6.69873 4.79698 6.58977 4.68801C6.48081 4.57905 6.4187 4.43185 6.41669 4.27776V3.88888C6.41669 3.73417 6.47815 3.58579 6.58754 3.4764C6.69694 3.367 6.84531 3.30554 7.00002 3.30554C7.15473 3.30554 7.3031 3.367 7.4125 3.4764C7.52189 3.58579 7.58335 3.73417 7.58335 3.88888V4.27776C7.58134 4.43185 7.51923 4.57905 7.41027 4.68801C7.30131 4.79698 7.1541 4.85908 7.00002 4.8611ZM7.00002 10.6945C6.84594 10.6925 6.69873 10.6304 6.58977 10.5214C6.48081 10.4124 6.4187 10.2652 6.41669 10.1111V6.22225C6.41669 6.06754 6.47815 5.91917 6.58754 5.80977C6.69694 5.70037 6.84531 5.63892 7.00002 5.63892C7.15473 5.63892 7.3031 5.70037 7.4125 5.80977C7.52189 5.91917 7.58335 6.06754 7.58335 6.22225V10.1111C7.58134 10.2652 7.51923 10.4124 7.41027 10.5214C7.30131 10.6304 7.1541 10.6925 7.00002 10.6945Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2 = [_hoisted_1];
|
||||
function render(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2, 16);
|
||||
}
|
||||
__name(render, "render");
|
||||
script.render = render;
|
||||
export {
|
||||
script$1 as a,
|
||||
script$3 as b,
|
||||
script$2 as c,
|
||||
script as s
|
||||
};
|
||||
//# sourceMappingURL=index-C_wOqB0f.js.map
|
1
web/assets/index-C_wOqB0f.js.map
generated
vendored
1
web/assets/index-C_wOqB0f.js.map
generated
vendored
File diff suppressed because one or more lines are too long
644
web/assets/index-CwRXxFdA.js → web/assets/index-CjwCGacA.js
generated
vendored
644
web/assets/index-CwRXxFdA.js → web/assets/index-CjwCGacA.js
generated
vendored
File diff suppressed because one or more lines are too long
1
web/assets/index-CjwCGacA.js.map
generated
vendored
Normal file
1
web/assets/index-CjwCGacA.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/index-CwRXxFdA.js.map
generated
vendored
1
web/assets/index-CwRXxFdA.js.map
generated
vendored
File diff suppressed because one or more lines are too long
50
web/assets/index-MX9DEi8Q.js
generated
vendored
Normal file
50
web/assets/index-MX9DEi8Q.js
generated
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { c7 as script$2, A as createBaseVNode, g as openBlock, h as createElementBlock, m as mergeProps } from "./index-B6dYHNhg.js";
|
||||
var script$1 = {
|
||||
name: "BarsIcon",
|
||||
"extends": script$2
|
||||
};
|
||||
var _hoisted_1$1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
d: "M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2$1 = [_hoisted_1$1];
|
||||
function render$1(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2$1, 16);
|
||||
}
|
||||
__name(render$1, "render$1");
|
||||
script$1.render = render$1;
|
||||
var script = {
|
||||
name: "PlusIcon",
|
||||
"extends": script$2
|
||||
};
|
||||
var _hoisted_1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
d: "M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2 = [_hoisted_1];
|
||||
function render(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2, 16);
|
||||
}
|
||||
__name(render, "render");
|
||||
script.render = render;
|
||||
export {
|
||||
script as a,
|
||||
script$1 as s
|
||||
};
|
||||
//# sourceMappingURL=index-MX9DEi8Q.js.map
|
1
web/assets/index-MX9DEi8Q.js.map
generated
vendored
Normal file
1
web/assets/index-MX9DEi8Q.js.map
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"index-MX9DEi8Q.js","sources":["../../node_modules/@primevue/icons/bars/index.mjs","../../node_modules/@primevue/icons/plus/index.mjs"],"sourcesContent":["import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'BarsIcon',\n \"extends\": BaseIcon\n};\n\nvar _hoisted_1 = /*#__PURE__*/createElementVNode(\"path\", {\n \"fill-rule\": \"evenodd\",\n \"clip-rule\": \"evenodd\",\n d: \"M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z\",\n fill: \"currentColor\"\n}, null, -1);\nvar _hoisted_2 = [_hoisted_1];\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _hoisted_2, 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n","import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'PlusIcon',\n \"extends\": BaseIcon\n};\n\nvar _hoisted_1 = /*#__PURE__*/createElementVNode(\"path\", {\n d: \"M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z\",\n fill: \"currentColor\"\n}, null, -1);\nvar _hoisted_2 = [_hoisted_1];\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _hoisted_2, 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n"],"names":["script","BaseIcon","_hoisted_1","createElementVNode","_hoisted_2","render"],"mappings":";;;AAGG,IAACA,WAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWC;AACb;AAEA,IAAIC,eAA0BC,gCAAmB,QAAQ;AAAA,EACvD,aAAa;AAAA,EACb,aAAa;AAAA,EACb,GAAG;AAAA,EACH,MAAM;AACR,GAAG,MAAM,EAAE;AACX,IAAIC,eAAa,CAACF,YAAU;AAC5B,SAASG,SAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,IAAG,CAAE,GAAGD,cAAY,EAAE;AAChC;AARSC;AAUTL,SAAO,SAASK;ACtBb,IAAC,SAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWJ;AACb;AAEA,IAAI,aAA0BE,gCAAmB,QAAQ;AAAA,EACvD,GAAG;AAAA,EACH,MAAM;AACR,GAAG,MAAM,EAAE;AACX,IAAI,aAAa,CAAC,UAAU;AAC5B,SAAS,OAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,IAAG,CAAE,GAAG,YAAY,EAAE;AAChC;AARS;AAUT,OAAO,SAAS;","x_google_ignoreList":[0,1]}
|
4
web/assets/userSelection-DITGVoWz.js → web/assets/userSelection-BSkuSZyR.js
generated
vendored
4
web/assets/userSelection-DITGVoWz.js → web/assets/userSelection-BSkuSZyR.js
generated
vendored
@@ -1,6 +1,6 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { bC as api, bG as $el } from "./index-BHayQCxv.js";
|
||||
import { bH as api, bW as $el } from "./index-B6dYHNhg.js";
|
||||
function createSpinner() {
|
||||
const div = document.createElement("div");
|
||||
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
|
||||
@@ -126,4 +126,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
||||
export {
|
||||
UserSelectionScreen
|
||||
};
|
||||
//# sourceMappingURL=userSelection-DITGVoWz.js.map
|
||||
//# sourceMappingURL=userSelection-BSkuSZyR.js.map
|
2
web/assets/userSelection-DITGVoWz.js.map → web/assets/userSelection-BSkuSZyR.js.map
generated
vendored
2
web/assets/userSelection-DITGVoWz.js.map → web/assets/userSelection-BSkuSZyR.js.map
generated
vendored
File diff suppressed because one or more lines are too long
28
web/assets/widgetInputs-DdecKYqd.js → web/assets/widgetInputs-BJ21PG7d.js
generated
vendored
28
web/assets/widgetInputs-DdecKYqd.js → web/assets/widgetInputs-BJ21PG7d.js
generated
vendored
@@ -1,6 +1,6 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { e as LGraphNode, c as app, bL as applyTextReplacements, bK as ComfyWidgets, bM as addValueControlWidgets, k as LiteGraph } from "./index-BHayQCxv.js";
|
||||
import { e as LGraphNode, c as app, c1 as applyTextReplacements, c0 as ComfyWidgets, c2 as addValueControlWidgets, k as LiteGraph } from "./index-B6dYHNhg.js";
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
const VALID_TYPES = [
|
||||
"STRING",
|
||||
@@ -171,7 +171,7 @@ class PrimitiveNode extends LGraphNode {
|
||||
if (type instanceof Array) {
|
||||
type = "COMBO";
|
||||
}
|
||||
const size = this.size;
|
||||
const [oldWidth, oldHeight] = this.size;
|
||||
let widget;
|
||||
if (type in ComfyWidgets) {
|
||||
widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget;
|
||||
@@ -218,8 +218,8 @@ class PrimitiveNode extends LGraphNode {
|
||||
return r;
|
||||
};
|
||||
this.size = [
|
||||
Math.max(this.size[0], size[0]),
|
||||
Math.max(this.size[1], size[1])
|
||||
Math.max(this.size[0], oldWidth),
|
||||
Math.max(this.size[1], oldHeight)
|
||||
];
|
||||
if (!recreating) {
|
||||
const sz = this.computeSize();
|
||||
@@ -320,7 +320,7 @@ class PrimitiveNode extends LGraphNode {
|
||||
}
|
||||
}
|
||||
function getWidgetConfig(slot) {
|
||||
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
|
||||
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]?.() ?? ["*", {}];
|
||||
}
|
||||
__name(getWidgetConfig, "getWidgetConfig");
|
||||
function getConfig(widgetName) {
|
||||
@@ -373,7 +373,7 @@ __name(showWidget, "showWidget");
|
||||
function convertToInput(node, widget, config) {
|
||||
hideWidget(node, widget);
|
||||
const { type } = getWidgetType(config);
|
||||
const sz = node.size;
|
||||
const [oldWidth, oldHeight] = node.size;
|
||||
const inputIsOptional = !!widget.options?.inputIsOptional;
|
||||
const input = node.addInput(widget.name, type, {
|
||||
widget: { name: widget.name, [GET_CONFIG]: () => config },
|
||||
@@ -382,18 +382,24 @@ function convertToInput(node, widget, config) {
|
||||
for (const widget2 of node.widgets) {
|
||||
widget2.last_y += LiteGraph.NODE_SLOT_HEIGHT;
|
||||
}
|
||||
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||
node.setSize([
|
||||
Math.max(oldWidth, node.size[0]),
|
||||
Math.max(oldHeight, node.size[1])
|
||||
]);
|
||||
return input;
|
||||
}
|
||||
__name(convertToInput, "convertToInput");
|
||||
function convertToWidget(node, widget) {
|
||||
showWidget(widget);
|
||||
const sz = node.size;
|
||||
const [oldWidth, oldHeight] = node.size;
|
||||
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
|
||||
for (const widget2 of node.widgets) {
|
||||
widget2.last_y -= LiteGraph.NODE_SLOT_HEIGHT;
|
||||
}
|
||||
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||
node.setSize([
|
||||
Math.max(oldWidth, node.size[0]),
|
||||
Math.max(oldHeight, node.size[1])
|
||||
]);
|
||||
}
|
||||
__name(convertToWidget, "convertToWidget");
|
||||
function getWidgetType(config) {
|
||||
@@ -450,7 +456,7 @@ function setWidgetConfig(slot, config, target) {
|
||||
__name(setWidgetConfig, "setWidgetConfig");
|
||||
function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
|
||||
if (!config1) {
|
||||
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
|
||||
config1 = getWidgetConfig(output);
|
||||
}
|
||||
if (config1[0] instanceof Array) {
|
||||
if (!isValidCombo(config1[0], config2[0])) return;
|
||||
@@ -753,4 +759,4 @@ export {
|
||||
mergeIfValid,
|
||||
setWidgetConfig
|
||||
};
|
||||
//# sourceMappingURL=widgetInputs-DdecKYqd.js.map
|
||||
//# sourceMappingURL=widgetInputs-BJ21PG7d.js.map
|
1
web/assets/widgetInputs-BJ21PG7d.js.map
generated
vendored
Normal file
1
web/assets/widgetInputs-BJ21PG7d.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/assets/widgetInputs-DdecKYqd.js.map
generated
vendored
1
web/assets/widgetInputs-DdecKYqd.js.map
generated
vendored
File diff suppressed because one or more lines are too long
3
web/extensions/core/vintageClipboard.js
vendored
Normal file
3
web/extensions/core/vintageClipboard.js
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
// Shim for extensions/core/vintageClipboard.ts
|
||||
export const serialise = window.comfyAPI.vintageClipboard.serialise;
|
||||
export const deserialiseAndCreate = window.comfyAPI.vintageClipboard.deserialiseAndCreate;
|
4
web/index.html
vendored
4
web/index.html
vendored
@@ -6,8 +6,8 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||
<script type="module" crossorigin src="./assets/index-BHayQCxv.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-BitceZ14.css">
|
||||
<script type="module" crossorigin src="./assets/index-B6dYHNhg.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-BCoLUtIt.css">
|
||||
</head>
|
||||
<body class="litegraph grid">
|
||||
<div id="vue-app"></div>
|
||||
|
1
web/scripts/changeTracker.js
vendored
1
web/scripts/changeTracker.js
vendored
@@ -1,3 +1,2 @@
|
||||
// Shim for scripts/changeTracker.ts
|
||||
export const ChangeTracker = window.comfyAPI.changeTracker.ChangeTracker;
|
||||
export const globalTracker = window.comfyAPI.changeTracker.globalTracker;
|
||||
|
2
web/scripts/defaultGraph.js
vendored
2
web/scripts/defaultGraph.js
vendored
@@ -1,2 +1,4 @@
|
||||
// Shim for scripts/defaultGraph.ts
|
||||
export const defaultGraph = window.comfyAPI.defaultGraph.defaultGraph;
|
||||
export const defaultGraphJSON = window.comfyAPI.defaultGraph.defaultGraphJSON;
|
||||
export const blankGraph = window.comfyAPI.defaultGraph.blankGraph;
|
||||
|
2
web/scripts/domWidget.js
vendored
2
web/scripts/domWidget.js
vendored
@@ -1,2 +0,0 @@
|
||||
// Shim for scripts/domWidget.ts
|
||||
export const addDomClippingSetting = window.comfyAPI.domWidget.addDomClippingSetting;
|
3
web/scripts/workflows.js
vendored
3
web/scripts/workflows.js
vendored
@@ -1,3 +0,0 @@
|
||||
// Shim for scripts/workflows.ts
|
||||
export const ComfyWorkflowManager = window.comfyAPI.workflows.ComfyWorkflowManager;
|
||||
export const ComfyWorkflow = window.comfyAPI.workflows.ComfyWorkflow;
|
Reference in New Issue
Block a user