mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-04 07:52:46 +08:00
Compare commits
35 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
1d47ec38d8 | ||
|
f87810cd3e | ||
|
10c919f4c7 | ||
|
ce80e69fb8 | ||
|
19944ad252 | ||
|
10b43ceea5 | ||
|
0a4c49c57c | ||
|
88ed893034 | ||
|
334ba48cea | ||
|
14764aa2e2 | ||
|
b2c995f623 | ||
|
4151fbfa8a | ||
|
6045ed31f8 | ||
|
f836e69346 | ||
|
5b69cfe7c3 | ||
|
95fa9545f1 | ||
|
11b74147ee | ||
|
6ab8cad22e | ||
|
011b11d8d7 | ||
|
ff6ca2a892 | ||
|
374e093e09 | ||
|
855789403b | ||
|
6f7869f365 | ||
|
281ad42df4 | ||
|
1cde6b2eff | ||
|
c5a48b15bd | ||
|
f2298799ba | ||
|
60383f3b64 | ||
|
8270c62530 | ||
|
821f93872e | ||
|
e1630391d6 | ||
|
99458e8aca | ||
|
33346fd9b8 | ||
|
136c93cb47 | ||
|
1305fb294c |
23
.github/workflows/pylint.yml
vendored
Normal file
23
.github/workflows/pylint.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Python Linting
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.x
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
|
4
.github/workflows/test-ui.yaml
vendored
4
.github/workflows/test-ui.yaml
vendored
@@ -24,3 +24,7 @@ jobs:
|
||||
npm run test:generate
|
||||
npm test -- --verbose
|
||||
working-directory: ./tests-ui
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -17,4 +17,5 @@ venv/
|
||||
!/web/extensions/core/
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
*.log
|
||||
*.log
|
||||
web_custom_versions/
|
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
188
app/frontend_management.py
Normal file
188
app/frontend_management.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class Asset(TypedDict):
|
||||
url: str
|
||||
|
||||
|
||||
class Release(TypedDict):
|
||||
id: int
|
||||
tag_name: str
|
||||
name: str
|
||||
prerelease: bool
|
||||
created_at: str
|
||||
published_at: str
|
||||
body: str
|
||||
assets: NotRequired[list[Asset]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontEndProvider:
|
||||
owner: str
|
||||
repo: str
|
||||
|
||||
@property
|
||||
def folder_name(self) -> str:
|
||||
return f"{self.owner}_{self.repo}"
|
||||
|
||||
@property
|
||||
def release_url(self) -> str:
|
||||
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
|
||||
|
||||
@cached_property
|
||||
def all_releases(self) -> list[Release]:
|
||||
releases = []
|
||||
api_url = self.release_url
|
||||
while api_url:
|
||||
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
releases.extend(response.json())
|
||||
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
|
||||
if "next" in response.links:
|
||||
api_url = response.links["next"]["url"]
|
||||
else:
|
||||
api_url = None
|
||||
return releases
|
||||
|
||||
@cached_property
|
||||
def latest_release(self) -> Release:
|
||||
latest_release_url = f"{self.release_url}/latest"
|
||||
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
return response.json()
|
||||
|
||||
def get_release(self, version: str) -> Release:
|
||||
if version == "latest":
|
||||
return self.latest_release
|
||||
else:
|
||||
for release in self.all_releases:
|
||||
if release["tag_name"] in [version, f"v{version}"]:
|
||||
return release
|
||||
raise ValueError(f"Version {version} not found in releases")
|
||||
|
||||
|
||||
def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
"""Download dist.zip from github release."""
|
||||
asset_url = None
|
||||
for asset in release.get("assets", []):
|
||||
if asset["name"] == "dist.zip":
|
||||
asset_url = asset["url"]
|
||||
break
|
||||
|
||||
if not asset_url:
|
||||
raise ValueError("dist.zip not found in the release assets")
|
||||
|
||||
# Use a temporary file to download the zip content
|
||||
with tempfile.TemporaryFile() as tmp_file:
|
||||
headers = {"Accept": "application/octet-stream"}
|
||||
response = requests.get(
|
||||
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
|
||||
)
|
||||
response.raise_for_status() # Ensure we got a successful response
|
||||
|
||||
# Write the content to the temporary file
|
||||
tmp_file.write(response.content)
|
||||
|
||||
# Go back to the beginning of the temporary file
|
||||
tmp_file.seek(0)
|
||||
|
||||
# Extract the zip file content to the destination path
|
||||
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_path)
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
"""
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
|
||||
match_result = re.match(VERSION_PATTERN, value)
|
||||
if match_result is None:
|
||||
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
||||
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
web_root = str(
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
@@ -13,6 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
|
||||
from ..ldm.modules.attention import SpatialTransformer
|
||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
||||
from ..ldm.util import exists
|
||||
from .control_types import UNION_CONTROLNET_TYPES
|
||||
from collections import OrderedDict
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
@@ -361,7 +362,7 @@ class ControlNet(nn.Module):
|
||||
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
|
||||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
||||
if idx < len(control_type):
|
||||
feat_seq += self.task_embedding[control_type[idx]]
|
||||
feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
|
||||
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(controlnet_cond)
|
||||
@@ -390,6 +391,18 @@ class ControlNet(nn.Module):
|
||||
if self.control_add_embedding is not None: #Union Controlnet
|
||||
control_type = kwargs.get("control_type", [])
|
||||
|
||||
if any([c >= self.num_control_type for c in control_type]):
|
||||
max_type = max(control_type)
|
||||
max_type_name = {
|
||||
v: k for k, v in UNION_CONTROLNET_TYPES.items()
|
||||
}[max_type]
|
||||
raise ValueError(
|
||||
f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
|
||||
f"({self.num_control_type}) supported.\n" +
|
||||
"Please consider using the ProMax ControlNet Union model.\n" +
|
||||
"https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
|
||||
)
|
||||
|
||||
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
|
||||
if len(control_type) > 0:
|
||||
if len(hint.shape) < 5:
|
||||
|
10
comfy/cldm/control_types.py
Normal file
10
comfy/cldm/control_types.py
Normal file
@@ -0,0 +1,10 @@
|
||||
UNION_CONTROLNET_TYPES = {
|
||||
"openpose": 0,
|
||||
"depth": 1,
|
||||
"hed/pidi/scribble/ted": 2,
|
||||
"canny/lineart/anime_lineart/mlsd": 3,
|
||||
"normal": 4,
|
||||
"segment": 5,
|
||||
"tile": 6,
|
||||
"repaint": 7,
|
||||
}
|
@@ -1,7 +1,10 @@
|
||||
import argparse
|
||||
import enum
|
||||
import os
|
||||
from typing import Optional
|
||||
import comfy.options
|
||||
|
||||
|
||||
class EnumAction(argparse.Action):
|
||||
"""
|
||||
Argparse action for handling Enums
|
||||
@@ -109,6 +112,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
@@ -124,6 +128,38 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-version",
|
||||
type=str,
|
||||
default=DEFAULT_VERSION_STRING,
|
||||
help="""
|
||||
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
|
||||
download available frontend implementations from GitHub releases.
|
||||
|
||||
The version string should be in the format of:
|
||||
[repoOwner]/[repoName]@[version]
|
||||
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
|
||||
""",
|
||||
)
|
||||
|
||||
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||
"""Validate if the given path is a directory."""
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
if not os.path.isdir(path):
|
||||
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
|
||||
return path
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-root",
|
||||
type=is_valid_directory,
|
||||
default=None,
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
@@ -34,6 +34,7 @@ class ClipVisionModel():
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -50,7 +51,7 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -93,7 +94,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
18
comfy/clip_vision_config_vitl_336.json
Normal file
18
comfy/clip_vision_config_vitl_336.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"attention_dropout": 0.0,
|
||||
"dropout": 0.0,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 336,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"model_type": "clip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32"
|
||||
}
|
@@ -45,6 +45,7 @@ class ControlBase:
|
||||
self.timestep_range = None
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
@@ -90,6 +91,7 @@ class ControlBase:
|
||||
c.compression_ratio = self.compression_ratio
|
||||
c.upscale_algorithm = self.upscale_algorithm
|
||||
c.latent_format = self.latent_format
|
||||
c.extra_args = self.extra_args.copy()
|
||||
c.vae = self.vae
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
@@ -135,6 +137,10 @@ class ControlBase:
|
||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||
return out
|
||||
|
||||
def set_extra_arg(self, argument, value=None):
|
||||
self.extra_args[argument] = value
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
@@ -191,7 +197,7 @@ class ControlNet(ControlBase):
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
|
@@ -7,6 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
|
||||
def default(x, y):
|
||||
if x is not None:
|
||||
@@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
/ half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
@@ -261,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
||||
unet_key_prefix = "model.model."
|
||||
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
|
||||
unet_key_prefix = "model."
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
"model.model.", #audio models
|
||||
]
|
||||
counts = {k: 0 for k in candidates}
|
||||
for k in state_dict:
|
||||
for c in candidates:
|
||||
if k.startswith(c):
|
||||
counts[c] += 1
|
||||
break
|
||||
|
||||
top = max(counts, key=counts.get)
|
||||
if counts[top] > 5:
|
||||
return top
|
||||
else:
|
||||
unet_key_prefix = "model.diffusion_model."
|
||||
return unet_key_prefix
|
||||
return "model." #aura flow and others
|
||||
|
||||
|
||||
def convert_config(unet_config):
|
||||
new_config = unet_config.copy()
|
||||
|
@@ -59,8 +59,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
beta_schedule = sampling_settings.get("beta_schedule", "linear")
|
||||
linear_start = sampling_settings.get("linear_start", 0.00085)
|
||||
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||
timesteps = sampling_settings.get("timesteps", 1000)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
|
@@ -6,6 +6,8 @@ from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.sampler_helpers
|
||||
import scipy
|
||||
import numpy
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
dims = tuple(x_in.shape[2:])
|
||||
@@ -311,13 +313,18 @@ def simple_scheduler(model_sampling, steps):
|
||||
def ddim_scheduler(model_sampling, steps):
|
||||
s = model_sampling
|
||||
sigs = []
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
x = 1
|
||||
if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
|
||||
steps += 1
|
||||
sigs = []
|
||||
else:
|
||||
sigs = [0.0]
|
||||
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
x += ss
|
||||
sigs = sigs[::-1]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||
@@ -325,15 +332,34 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
append_zero = True
|
||||
if sgm:
|
||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||
else:
|
||||
if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
|
||||
steps += 1
|
||||
append_zero = False
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
|
||||
sigs = []
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
sigs.append(s.sigma(ts))
|
||||
sigs.append(float(s.sigma(ts)))
|
||||
|
||||
if append_zero:
|
||||
sigs += [0.0]
|
||||
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
# Implemented based on: https://arxiv.org/abs/2407.12173
|
||||
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
||||
total_timesteps = (len(model_sampling.sigmas) - 1)
|
||||
ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
|
||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||
|
||||
sigs = []
|
||||
for t in ts:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
@@ -703,7 +729,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
@@ -719,6 +745,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
|
32
comfy/sd.py
32
comfy/sd.py
@@ -19,8 +19,8 @@ from . import model_detection
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
|
||||
import comfy.model_patcher
|
||||
@@ -60,7 +60,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@@ -79,7 +79,7 @@ class CLIP:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.layer_idx = None
|
||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||
@@ -135,7 +135,11 @@ class CLIP:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
sd_clip = self.cond_stage_model.state_dict()
|
||||
sd_tokenizer = self.tokenizer.state_dict()
|
||||
for k in sd_tokenizer:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
@@ -414,27 +418,27 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif weight.shape[-1] == 2048:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
clip_target.clip = sa_t5.SAT5Model
|
||||
clip_target.tokenizer = sa_t5.SAT5Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
for c in clip_data:
|
||||
@@ -520,7 +524,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
|
@@ -386,7 +386,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
@@ -519,12 +519,14 @@ class SDTokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
@@ -534,6 +536,8 @@ class SD1Tokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
||||
|
@@ -11,12 +11,12 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
|
@@ -16,12 +16,12 @@ class SDXLClipG(sd1_clip.SDClipModel):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
@@ -34,6 +34,9 @@ class SDXLTokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
@@ -68,12 +71,12 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
||||
|
@@ -5,8 +5,8 @@ from . import utils
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
|
||||
from . import supported_models_base
|
||||
@@ -524,7 +524,7 @@ class SD3(supported_models_base.BASE):
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -555,7 +555,7 @@ class StableAudio(supported_models_base.BASE):
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)
|
||||
|
||||
class AuraFlow(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
|
@@ -1,21 +1,21 @@
|
||||
from comfy import sd1_clip
|
||||
from .llama_tokenizer import LLAMATokenizer
|
||||
import comfy.t5
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class PT5XlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
|
||||
|
||||
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
|
||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
|
@@ -1,22 +0,0 @@
|
||||
import os
|
||||
|
||||
class LLAMATokenizer:
|
||||
@staticmethod
|
||||
def from_pretrained(path):
|
||||
return LLAMATokenizer(path)
|
||||
|
||||
def __init__(self, tokenizer_path):
|
||||
import sentencepiece
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
|
||||
self.end = self.tokenizer.eos_id()
|
||||
|
||||
def get_vocab(self):
|
||||
out = {}
|
||||
for i in range(self.tokenizer.get_piece_size()):
|
||||
out[self.tokenizer.id_to_piece(i)] = i
|
||||
return out
|
||||
|
||||
def __call__(self, string):
|
||||
out = self.tokenizer.encode(string)
|
||||
out += [self.end]
|
||||
return {"input_ids": out}
|
@@ -1,21 +1,21 @@
|
||||
from comfy import sd1_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||
|
||||
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
@@ -1,7 +1,7 @@
|
||||
from comfy import sd1_clip
|
||||
from comfy import sdxl_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import comfy.text_encoders.t5
|
||||
import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
@@ -10,25 +10,16 @@ import logging
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
class SDT5XXLModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
|
||||
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
@@ -43,6 +34,9 @@ class SD3Tokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
super().__init__()
|
29
comfy/text_encoders/spiece_tokenizer.py
Normal file
29
comfy/text_encoders/spiece_tokenizer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
class SPieceTokenizer:
|
||||
add_eos = True
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path):
|
||||
return SPieceTokenizer(path)
|
||||
|
||||
def __init__(self, tokenizer_path):
|
||||
import sentencepiece
|
||||
if torch.is_tensor(tokenizer_path):
|
||||
tokenizer_path = tokenizer_path.numpy().tobytes()
|
||||
|
||||
if isinstance(tokenizer_path, bytes):
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
|
||||
else:
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
|
||||
|
||||
def get_vocab(self):
|
||||
out = {}
|
||||
for i in range(self.tokenizer.get_piece_size()):
|
||||
out[self.tokenizer.id_to_piece(i)] = i
|
||||
return out
|
||||
|
||||
def __call__(self, string):
|
||||
out = self.tokenizer.encode(string)
|
||||
return {"input_ids": out}
|
@@ -223,7 +223,7 @@ class T5(torch.nn.Module):
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
|
||||
|
@@ -147,7 +147,7 @@ class SaveAudio:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"]):
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.flac"
|
||||
|
||||
|
27
comfy_extras/nodes_controlnet.py
Normal file
27
comfy_extras/nodes_controlnet.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
|
||||
class SetUnionControlNetType:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"control_net": ("CONTROL_NET", ),
|
||||
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
|
||||
}}
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
|
||||
FUNCTION = "set_controlnet_type"
|
||||
|
||||
def set_controlnet_type(self, control_net, type):
|
||||
control_net = control_net.copy()
|
||||
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
||||
if type_number >= 0:
|
||||
control_net.set_extra_arg("control_type", [type_number])
|
||||
else:
|
||||
control_net.set_extra_arg("control_type", [])
|
||||
|
||||
return (control_net,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SetUnionControlNetType": SetUnionControlNetType,
|
||||
}
|
@@ -111,6 +111,25 @@ class SDTurboScheduler:
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
return (sigmas, )
|
||||
|
||||
class BetaSamplingScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
|
||||
"beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, steps, alpha, beta):
|
||||
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
|
||||
return (sigmas, )
|
||||
|
||||
class VPScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -638,6 +657,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ExponentialScheduler": ExponentialScheduler,
|
||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||
"VPScheduler": VPScheduler,
|
||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||
"SDTurboScheduler": SDTurboScheduler,
|
||||
"KSamplerSelect": KSamplerSelect,
|
||||
"SamplerEulerAncestral": SamplerEulerAncestral,
|
||||
|
@@ -34,7 +34,7 @@ class FreeU:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "model_patches"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
@@ -73,7 +73,7 @@ class FreeU_V2:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "model_patches"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
|
@@ -32,7 +32,7 @@ class HyperTile:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "model_patches"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
|
@@ -19,7 +19,7 @@ class PerturbedAttentionGuidance:
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, scale):
|
||||
unet_block = "middle"
|
||||
|
@@ -92,7 +92,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
CATEGORY = "_for_testing/sd3"
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripleCLIPLoader": TripleCLIPLoader,
|
||||
|
12
execution.py
12
execution.py
@@ -3,6 +3,7 @@ import copy
|
||||
import logging
|
||||
import threading
|
||||
import heapq
|
||||
import time
|
||||
import traceback
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
@@ -247,6 +248,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif class_type != old_prompt[unique_id]['class_type']:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@@ -281,7 +284,11 @@ class PromptExecutor:
|
||||
self.success = True
|
||||
self.old_prompt = {}
|
||||
|
||||
def add_message(self, event, data, broadcast: bool):
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
**data,
|
||||
"timestamp": int(time.time() * 1000),
|
||||
}
|
||||
self.status_messages.append((event, data))
|
||||
if self.server.client_id is not None or broadcast:
|
||||
self.server.send_sync(event, data, self.server.client_id)
|
||||
@@ -392,6 +399,9 @@ class PromptExecutor:
|
||||
if self.success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
|
24
fix_torch.py
Normal file
24
fix_torch.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import importlib.util
|
||||
import shutil
|
||||
import os
|
||||
import ctypes
|
||||
import logging
|
||||
|
||||
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
lib_folder = os.path.join(folder, "lib")
|
||||
test_file = os.path.join(lib_folder, "fbgemm.dll")
|
||||
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
|
||||
if os.path.exists(dest):
|
||||
break
|
||||
|
||||
with open(test_file, 'rb') as f:
|
||||
contents = f.read()
|
||||
if b"libomp140.x86_64.dll" not in contents:
|
||||
break
|
||||
try:
|
||||
mydll = ctypes.cdll.LoadLibrary(test_file)
|
||||
except FileNotFoundError as e:
|
||||
logging.warning("Detected pytorch version with libomp issue, patching.")
|
||||
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
|
6
main.py
6
main.py
@@ -74,6 +74,12 @@ if __name__ == "__main__":
|
||||
|
||||
import cuda_malloc
|
||||
|
||||
if args.windows_standalone_build:
|
||||
try:
|
||||
import fix_torch
|
||||
except:
|
||||
pass
|
||||
|
||||
import comfy.utils
|
||||
import yaml
|
||||
|
||||
|
@@ -1,3 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
from PIL import ImageFile, UnidentifiedImageError
|
||||
|
||||
def conditioning_set_values(conditioning, values={}):
|
||||
@@ -22,3 +26,12 @@ def pillow(fn, arg):
|
||||
if prev_value is not None:
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
|
||||
return x
|
||||
|
||||
def hasher():
|
||||
hashfuncs = {
|
||||
"md5": hashlib.md5,
|
||||
"sha1": hashlib.sha1,
|
||||
"sha256": hashlib.sha256,
|
||||
"sha512": hashlib.sha512
|
||||
}
|
||||
return hashfuncs[args.default_hashing_function]
|
||||
|
35
nodes.py
35
nodes.py
@@ -748,7 +748,7 @@ class ControlNetApply:
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
if strength == 0:
|
||||
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
||||
if strength == 0:
|
||||
@@ -1890,29 +1890,29 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
EXTENSION_WEB_DIRS = {}
|
||||
|
||||
|
||||
def get_relative_module_name(module_path: str) -> str:
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
Returns the module name based on the given module path.
|
||||
Examples:
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "custom_nodes.my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "custom_nodes.my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "custom_nodes.my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "custom_nodes.my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "custom_nodes.my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "custom_nodes.my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes.my
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "my_custom_node"
|
||||
get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes
|
||||
Args:
|
||||
module_path (str): The path of the module.
|
||||
Returns:
|
||||
str: The module name.
|
||||
"""
|
||||
relative_path = os.path.relpath(module_path, folder_paths.base_path)
|
||||
base_path = os.path.basename(module_path)
|
||||
if os.path.isfile(module_path):
|
||||
relative_path = os.path.splitext(relative_path)[0]
|
||||
return relative_path.replace(os.sep, '.')
|
||||
base_path = os.path.splitext(base_path)[0]
|
||||
return base_path
|
||||
|
||||
|
||||
def load_custom_node(module_path: str, ignore=set()) -> bool:
|
||||
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
module_name = os.path.basename(module_path)
|
||||
if os.path.isfile(module_path):
|
||||
sp = os.path.splitext(module_path)
|
||||
@@ -1939,7 +1939,7 @@ def load_custom_node(module_path: str, ignore=set()) -> bool:
|
||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||
if name not in ignore:
|
||||
NODE_CLASS_MAPPINGS[name] = node_cls
|
||||
node_cls.RELATIVE_PYTHON_MODULE = get_relative_module_name(module_path)
|
||||
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
@@ -1974,7 +1974,7 @@ def init_external_custom_nodes():
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path, base_node_names)
|
||||
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
@@ -2036,11 +2036,12 @@ def init_builtin_extra_nodes():
|
||||
"nodes_audio.py",
|
||||
"nodes_sd3.py",
|
||||
"nodes_gits.py",
|
||||
"nodes_controlnet.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
for node_file in extras_files:
|
||||
if not load_custom_node(os.path.join(extras_dir, node_file)):
|
||||
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||
import_failed.append(node_file)
|
||||
|
||||
return import_failed
|
||||
|
@@ -1,5 +1,8 @@
|
||||
[pytest]
|
||||
markers =
|
||||
inference: mark as inference test (deselect with '-m "not inference"')
|
||||
testpaths = tests
|
||||
addopts = -s
|
||||
testpaths =
|
||||
tests
|
||||
tests-unit
|
||||
addopts = -s
|
||||
pythonpath = .
|
||||
|
@@ -1,7 +1,7 @@
|
||||
torch
|
||||
torch==2.3.1
|
||||
torchsde
|
||||
torchvision
|
||||
torchaudio
|
||||
torchvision==0.18.1
|
||||
torchaudio==2.3.1
|
||||
einops
|
||||
transformers>=4.28.1
|
||||
tokenizers>=0.13.3
|
||||
@@ -13,6 +13,7 @@ Pillow
|
||||
scipy
|
||||
tqdm
|
||||
psutil
|
||||
numpy<2.0.0
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
18
server.py
18
server.py
@@ -25,9 +25,11 @@ import mimetypes
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
import node_helpers
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.user_manager import UserManager
|
||||
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
@@ -83,8 +85,12 @@ class PromptServer():
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
self.web_root = os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), "web")
|
||||
self.web_root = (
|
||||
FrontendManager.init_frontend(args.front_end_version)
|
||||
if args.front_end_root is None
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@@ -156,10 +162,12 @@ class PromptServer():
|
||||
return type_dir, dir_type
|
||||
|
||||
def compare_image_hash(filepath, image):
|
||||
hasher = node_helpers.hasher()
|
||||
|
||||
# function to compare hashes of two images to see if it already exists, fix to #3465
|
||||
if os.path.exists(filepath):
|
||||
a = hashlib.sha256()
|
||||
b = hashlib.sha256()
|
||||
a = hasher()
|
||||
b = hasher()
|
||||
with open(filepath, "rb") as f:
|
||||
a.update(f.read())
|
||||
b.update(image.file.read())
|
||||
|
8
tests-unit/README.md
Normal file
8
tests-unit/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Pytest Unit Tests
|
||||
|
||||
## Install test dependencies
|
||||
|
||||
`pip install -r tests-units/requirements.txt`
|
||||
|
||||
## Run tests
|
||||
`pytest tests-units/`
|
0
tests-unit/app_test/__init__.py
Normal file
0
tests-unit/app_test/__init__.py
Normal file
100
tests-unit/app_test/frontend_manager_test.py
Normal file
100
tests-unit/app_test/frontend_manager_test.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import argparse
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from app.frontend_management import (
|
||||
FrontendManager,
|
||||
FrontEndProvider,
|
||||
Release,
|
||||
)
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_releases():
|
||||
return [
|
||||
Release(
|
||||
id=1,
|
||||
tag_name="1.0.0",
|
||||
name="Release 1.0.0",
|
||||
prerelease=False,
|
||||
created_at="2022-01-01T00:00:00Z",
|
||||
published_at="2022-01-01T00:00:00Z",
|
||||
body="Release notes for 1.0.0",
|
||||
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
|
||||
),
|
||||
Release(
|
||||
id=2,
|
||||
tag_name="2.0.0",
|
||||
name="Release 2.0.0",
|
||||
prerelease=False,
|
||||
created_at="2022-02-01T00:00:00Z",
|
||||
published_at="2022-02-01T00:00:00Z",
|
||||
body="Release notes for 2.0.0",
|
||||
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider(mock_releases):
|
||||
provider = FrontEndProvider(
|
||||
owner="test-owner",
|
||||
repo="test-repo",
|
||||
)
|
||||
provider.all_releases = mock_releases
|
||||
provider.latest_release = mock_releases[1]
|
||||
FrontendManager.PROVIDERS = [provider]
|
||||
return provider
|
||||
|
||||
|
||||
def test_get_release(mock_provider, mock_releases):
|
||||
version = "1.0.0"
|
||||
release = mock_provider.get_release(version)
|
||||
assert release == mock_releases[0]
|
||||
|
||||
|
||||
def test_get_release_latest(mock_provider, mock_releases):
|
||||
version = "latest"
|
||||
release = mock_provider.get_release(version)
|
||||
assert release == mock_releases[1]
|
||||
|
||||
|
||||
def test_get_release_invalid_version(mock_provider):
|
||||
version = "invalid"
|
||||
with pytest.raises(ValueError):
|
||||
mock_provider.get_release(version)
|
||||
|
||||
|
||||
def test_init_frontend_default():
|
||||
version_string = DEFAULT_VERSION_STRING
|
||||
frontend_path = FrontendManager.init_frontend(version_string)
|
||||
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
|
||||
|
||||
|
||||
def test_init_frontend_invalid_version():
|
||||
version_string = "test-owner/test-repo@1.100.99"
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
|
||||
def test_init_frontend_invalid_provider():
|
||||
version_string = "invalid/invalid@latest"
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
|
||||
def test_parse_version_string():
|
||||
version_string = "owner/repo@1.0.0"
|
||||
repo_owner, repo_name, version = FrontendManager.parse_version_string(
|
||||
version_string
|
||||
)
|
||||
assert repo_owner == "owner"
|
||||
assert repo_name == "repo"
|
||||
assert version == "1.0.0"
|
||||
|
||||
|
||||
def test_parse_version_string_invalid():
|
||||
version_string = "invalid"
|
||||
with pytest.raises(argparse.ArgumentTypeError):
|
||||
FrontendManager.parse_version_string(version_string)
|
1
tests-unit/requirements.txt
Normal file
1
tests-unit/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
pytest>=7.8.0
|
@@ -17,7 +17,6 @@ function getResourceURL(subfolder, filename, type = "input") {
|
||||
"filename=" + encodeURIComponent(filename),
|
||||
"type=" + type,
|
||||
"subfolder=" + subfolder,
|
||||
app.getPreviewFormatParam().substring(1),
|
||||
app.getRandParam().substring(1)
|
||||
].join("&")
|
||||
|
||||
@@ -150,6 +149,15 @@ app.registerExtension({
|
||||
}
|
||||
audioWidget.callback = onAudioWidgetUpdate
|
||||
|
||||
// Load saved audio file widget values if restoring from workflow
|
||||
const onGraphConfigured = node.onGraphConfigured;
|
||||
node.onGraphConfigured = function() {
|
||||
onGraphConfigured?.apply(this, arguments)
|
||||
if (audioWidget.value) {
|
||||
onAudioWidgetUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
const fileInput = document.createElement("input")
|
||||
fileInput.type = "file"
|
||||
fileInput.accept = "audio/*"
|
||||
|
@@ -136,6 +136,9 @@ class ComfyApi extends EventTarget {
|
||||
case "execution_start":
|
||||
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_success":
|
||||
this.dispatchEvent(new CustomEvent("execution_success", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_error":
|
||||
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||
break;
|
||||
|
@@ -49,7 +49,7 @@ export function getPngMetadata(file) {
|
||||
|
||||
function parseExifData(exifData) {
|
||||
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
|
||||
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
|
||||
const isLittleEndian = String.fromCharCode(...exifData.slice(0, 2)) === "II";
|
||||
|
||||
// Function to read 16-bit and 32-bit integers from binary data
|
||||
function readInt(offset, isLittleEndian, length) {
|
||||
@@ -134,6 +134,7 @@ export function getWebpMetadata(file) {
|
||||
let index = value.indexOf(':');
|
||||
txt_chunks[value.slice(0, index)] = value.slice(index + 1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
offset += 8 + chunk_length;
|
||||
|
Reference in New Issue
Block a user