Mirror of https://github.com/comfyanonymous/ComfyUI.git
Compare commits: v0.3.39...js/drafts/ (108 commits)
Commits (SHA1):

0254d9cc11 92f9a10782 a6a6b615f4 50bf72f852 46c8311d14 772de7c006 b22e97dcfa f02de13316
c46268bf60 cf49a2c5b5 170c7bb90c 2a0b138feb e195c1b13f 5b4eb021cb 396454fa41 a3cf272522
ba9548f756 e18f53cca9 c36be0ea09 9093301a49 bd951a714f 6493709d6a b976f934ae 7d8cf4cacc
68f4496b8e ef5266b1c1 a96e65df18 93a49a45de ec70ed6aea 7a13f74220 8042eb20c6 bd9f166c12
dd94416db2 ae0e7c4dff 78f79266a9 1883e70b43 31ca603ccb f7fb193712 7e9267fa77 91d40086db
5b12b55e32 e9e9a031a8 d7430c529a cd88f709ab 4459a17e82 483b3e62e0 8e81c507d2 e1c6dc720e
7ea79ebb9d ae75a084df d6a2137fc3 53e8d8193c 29596bd53f 803af1e0c3 6673939e76 f74778e75d
520eb77b72 5bf69bde35 c69af655aa 251f54a2ad c6529c0d77 baa8c8cdd3 40fd39c7cb 4d1c4b9797
d2566eb4b2 ef7e885fe4 ecb8d15e7a 365f9ed157 50c605e957 9685d4f3c3 8a4ff747bd af1eb58be8
373a9386a4 6e28a46454 c7b25784b1 7f800d04fa 97755eed46 daf9d25ee2 3b4b171e18 d8759c772b
4248b1618f 866f6cdab4 3aa83feeec 871749c208 fcc1643c52 20687293fe 47d55b8b45 310f4b6ef8
856448060c 312d511630 4f4f1c642a 010954d277 6d46bb4b4c 67f57c5bcc fd943c928f d3bd983b91
fb4754624d 180db6753f d062fcc5c0 456abad834 19e45e9b0e 97f23b81f3 08b7cc7506 6c319cbb4e
df1aebe52e 704fc78854 1d9fee79fd aeba0b3a26
.github/ISSUE_TEMPLATE/bug-report.yml (vendored, 8 changed lines)

@@ -15,6 +15,14 @@ body:
        steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.

        If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
  - type: checkboxes
    id: custom-nodes-test
    attributes:
      label: Custom Node Testing
      description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
      options:
        - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
          required: true
  - type: textarea
    attributes:
      label: Expected Behavior
.github/ISSUE_TEMPLATE/user-support.yml (vendored, 8 changed lines)

@@ -11,6 +11,14 @@ body:
        **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.

        If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
  - type: checkboxes
    id: custom-nodes-test
    attributes:
      label: Custom Node Testing
      description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
      options:
        - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
          required: true
  - type: textarea
    attributes:
      label: Your question
.github/workflows/release-webhook.yml (new file, vendored, 108 lines)
@@ -0,0 +1,108 @@
|
||||
name: Release Webhook
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
||||
run: |
|
||||
# Generate UUID for delivery ID
|
||||
DELIVERY_ID=$(uuidgen)
|
||||
HOOK_ID="release-webhook-$(date +%s)"
|
||||
|
||||
# Create webhook payload matching GitHub release webhook format
|
||||
PAYLOAD=$(cat <<EOF
|
||||
{
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": ${{ github.event.release.id }},
|
||||
"node_id": "${{ github.event.release.node_id }}",
|
||||
"url": "${{ github.event.release.url }}",
|
||||
"html_url": "${{ github.event.release.html_url }}",
|
||||
"assets_url": "${{ github.event.release.assets_url }}",
|
||||
"upload_url": "${{ github.event.release.upload_url }}",
|
||||
"tag_name": "${{ github.event.release.tag_name }}",
|
||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
||||
"name": ${{ toJSON(github.event.release.name) }},
|
||||
"body": ${{ toJSON(github.event.release.body) }},
|
||||
"draft": ${{ github.event.release.draft }},
|
||||
"prerelease": ${{ github.event.release.prerelease }},
|
||||
"created_at": "${{ github.event.release.created_at }}",
|
||||
"published_at": "${{ github.event.release.published_at }}",
|
||||
"author": {
|
||||
"login": "${{ github.event.release.author.login }}",
|
||||
"id": ${{ github.event.release.author.id }},
|
||||
"node_id": "${{ github.event.release.author.node_id }}",
|
||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
||||
"url": "${{ github.event.release.author.url }}",
|
||||
"html_url": "${{ github.event.release.author.html_url }}",
|
||||
"type": "${{ github.event.release.author.type }}",
|
||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
||||
},
|
||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
||||
},
|
||||
"repository": {
|
||||
"id": ${{ github.event.repository.id }},
|
||||
"node_id": "${{ github.event.repository.node_id }}",
|
||||
"name": "${{ github.event.repository.name }}",
|
||||
"full_name": "${{ github.event.repository.full_name }}",
|
||||
"private": ${{ github.event.repository.private }},
|
||||
"owner": {
|
||||
"login": "${{ github.event.repository.owner.login }}",
|
||||
"id": ${{ github.event.repository.owner.id }},
|
||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
||||
"url": "${{ github.event.repository.owner.url }}",
|
||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
||||
"type": "${{ github.event.repository.owner.type }}",
|
||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
||||
},
|
||||
"html_url": "${{ github.event.repository.html_url }}",
|
||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
||||
"git_url": "${{ github.event.repository.git_url }}",
|
||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
||||
"url": "${{ github.event.repository.url }}",
|
||||
"created_at": "${{ github.event.repository.created_at }}",
|
||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
||||
"fork": ${{ github.event.repository.fork }}
|
||||
},
|
||||
"sender": {
|
||||
"login": "${{ github.event.sender.login }}",
|
||||
"id": ${{ github.event.sender.id }},
|
||||
"node_id": "${{ github.event.sender.node_id }}",
|
||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
||||
"url": "${{ github.event.sender.url }}",
|
||||
"html_url": "${{ github.event.sender.html_url }}",
|
||||
"type": "${{ github.event.sender.type }}",
|
||||
"site_admin": ${{ github.event.sender.site_admin }}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
# Generate HMAC-SHA256 signature
|
||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
||||
|
||||
# Send webhook with required headers
|
||||
curl -X POST "$WEBHOOK_URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-GitHub-Event: release" \
|
||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
||||
-d "$PAYLOAD" \
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
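The workflow above signs the JSON payload with HMAC-SHA256 and sends it in the X-Hub-Signature-256 header, mimicking GitHub's native webhook delivery format. As a sketch only (not part of this change), a receiving endpoint could verify that header against the shared secret like this, assuming a Flask handler:

```python
# Sketch only: a hypothetical receiver for the release webhook, verifying the
# X-Hub-Signature-256 header. Assumes a Flask app and the same shared secret
# the workflow reads from RELEASE_GITHUB_WEBHOOK_SECRET.
import hashlib
import hmac

from flask import Flask, abort, request

app = Flask(__name__)
WEBHOOK_SECRET = b"replace-with-shared-secret"  # assumption: loaded from real config in practice


@app.post("/release-webhook")
def release_webhook():
    # Recompute the HMAC-SHA256 over the raw body and compare in constant time.
    expected = "sha256=" + hmac.new(WEBHOOK_SECRET, request.get_data(), hashlib.sha256).hexdigest()
    received = request.headers.get("X-Hub-Signature-256", "")
    if not hmac.compare_digest(expected, received):
        abort(401)
    payload = request.get_json()
    return {"ok": True, "tag": payload["release"]["tag_name"]}
```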
.github/workflows/stable-release.yml (vendored, 3 changed lines)

@@ -102,5 +102,4 @@ jobs:
          file: ComfyUI_windows_portable_nvidia.7z
          tag: ${{ inputs.git_tag }}
          overwrite: true
          prerelease: true
          make_latest: false
          draft: true
.github/workflows/test-unit.yml (vendored, 4 changed lines)

@@ -28,3 +28,7 @@ jobs:
        run: |
          pip install -r tests-unit/requirements.txt
          python -m pytest tests-unit
      - name: Run Execution Model Tests
        run: |
          python -m pytest tests/inference/test_execution.py
CODEOWNERS (26 changed lines)

@@ -5,20 +5,20 @@
# Inlined the team members for now.

# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne

# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne

# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
README.md (14 changed lines)

@@ -6,6 +6,7 @@
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Twitter][twitter-shield]][twitter-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
@@ -20,6 +21,8 @@
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI

[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
@@ -62,12 +65,16 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
  - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
  - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
  - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
  - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- Image Editing Models
  - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
  - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- Video Models
  - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
  - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
  - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
  - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
  - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
  - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
  - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- Audio Models
  - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -95,7 +102,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.

Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -268,6 +276,8 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve

#### DirectML (AMD Cards on Windows)

This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.

```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```

#### Ascend NPUs
alembic.ini (new file, 84 lines)
@@ -0,0 +1,84 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic_db
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic_db/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
alembic_db/README.md (new file, 4 lines)

@@ -0,0 +1,4 @@
## Generate new revision

1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"`
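As an illustration of step 1, a hypothetical table added to `app/database/models.py` might look like the sketch below; step 2 then generates the migration script. The `ModelAsset` class is an invented example, not part of this change.

```python
# Step 1 sketch: a hypothetical table added to app/database/models.py.
# (ModelAsset is an invented example; it is not part of this change.)
from sqlalchemy import Column, Integer, String

from app.database.models import Base


class ModelAsset(Base):
    __tablename__ = "model_asset"

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    path = Column(String)

# Step 2 is then run from the repo root:
#   alembic revision --autogenerate -m "add model_asset table"
```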
alembic_db/env.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
alembic_db/script.py.mako (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
app/database/db.py (new file, 112 lines)

@@ -0,0 +1,112 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args

_DB_AVAILABLE = False
Session = None


try:
    from alembic import command
    from alembic.config import Config
    from alembic.runtime.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    _DB_AVAILABLE = True
except ImportError as e:
    log_startup_warning(
        f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
    )


def dependencies_available():
    """
    Temporary function to check if the dependencies are available
    """
    return _DB_AVAILABLE


def can_create_session():
    """
    Temporary function to check if the database is available to create a session
    During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
    """
    return dependencies_available() and Session is not None


def get_alembic_config():
    root_path = os.path.join(os.path.dirname(__file__), "../..")
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))

    config = Config(config_path)
    config.set_main_option("script_location", scripts_path)
    config.set_main_option("sqlalchemy.url", args.database_url)

    return config


def get_db_path():
    url = args.database_url
    if url.startswith("sqlite:///"):
        return url.split("///")[1]
    else:
        raise ValueError(f"Unsupported database URL '{url}'.")


def init_db():
    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")
    db_path = get_db_path()
    db_exists = os.path.exists(db_path)

    config = get_alembic_config()

    # Check if we need to upgrade
    engine = create_engine(db_url)
    conn = engine.connect()

    context = MigrationContext.configure(conn)
    current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(config)
    target_rev = script.get_current_head()

    if target_rev is None:
        logging.warning("No target revision found.")
    elif current_rev != target_rev:
        # Backup the database pre upgrade
        backup_path = db_path + ".bkp"
        if db_exists:
            shutil.copy(db_path, backup_path)
        else:
            backup_path = None

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception as e:
            if backup_path:
                # Restore the database from backup if upgrade fails
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.exception("Error upgrading database: ")
            raise e

    global Session
    Session = sessionmaker(bind=engine)


def create_session():
    return Session()
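A minimal sketch of how this module is intended to be driven at startup, assuming the new requirements (alembic, sqlalchemy) are installed; the calls mirror the functions defined above:

```python
# Sketch only: driving app/database/db.py at startup.
from app.database import db

if db.dependencies_available():
    db.init_db()  # applies Alembic migrations, backing up the sqlite file first

if db.can_create_session():
    session = db.create_session()
    try:
        pass  # query the models defined in app/database/models.py here
    finally:
        session.close()
```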
app/database/models.py (new file, 14 lines)

@@ -0,0 +1,14 @@
from sqlalchemy.orm import declarative_base

Base = declarative_base()


def to_dict(obj):
    fields = obj.__table__.columns.keys()
    return {
        field: (val.to_dict() if hasattr(val, "to_dict") else val)
        for field in fields
        if (val := getattr(obj, field))
    }

# TODO: Define models here
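One detail worth noting in `to_dict`: the walrus in the comprehension filters out falsy column values, so fields that are `None`, `0`, or empty strings are omitted. A small sketch, reusing the hypothetical `ModelAsset` from the earlier migration example:

```python
# Sketch: to_dict on a hypothetical mapped object (ModelAsset from the earlier sketch).
# Falsy columns (None, 0, "", False) are filtered out by the walrus condition.
from app.database.models import to_dict

asset = ModelAsset(id=3, name="sd15.safetensors", path=None)
print(to_dict(asset))  # {'id': 3, 'name': 'sd15.safetensors'} (path is omitted)
```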
@@ -16,26 +16,17 @@ from importlib.metadata import version
import requests
from typing_extensions import NotRequired

from utils.install_util import get_missing_requirements_message, requirements_path

from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger

# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"


def frontend_install_warning_message():
    """The warning message to display when the frontend version is not up to date."""

    extra = ""
    if sys.flags.no_user_site:
        extra = "-s "
    return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
{get_missing_requirements_message()}

This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.

If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()


@@ -48,7 +39,7 @@ def check_frontend_version():
    try:
        frontend_version_str = version("comfyui-frontend-package")
        frontend_version = parse_version(frontend_version_str)
        with open(req_path, "r", encoding="utf-8") as f:
        with open(requirements_path, "r", encoding="utf-8") as f:
            required_frontend = parse_version(f.readline().split("=")[-1])
        if frontend_version < required_frontend:
            app.logger.log_startup_warning(
@@ -121,9 +112,22 @@ class FrontEndProvider:
        response.raise_for_status() # Raises an HTTPError if the response was an error
        return response.json()

    @cached_property
    def latest_prerelease(self) -> Release:
        """Get the latest pre-release version - even if it's older than the latest release"""
        release = [release for release in self.all_releases if release["prerelease"]]

        if not release:
            raise ValueError("No pre-releases found")

        # GitHub returns releases in reverse chronological order, so first is latest
        return release[0]

    def get_release(self, version: str) -> Release:
        if version == "latest":
            return self.latest_release
        elif version == "prerelease":
            return self.latest_prerelease
        else:
            for release in self.all_releases:
                if release["tag_name"] in [version, f"v{version}"]:
@@ -205,6 +209,19 @@ comfyui-workflow-templates is not installed.
""".strip()
        )

    @classmethod
    def embedded_docs_path(cls) -> str:
        """Get the path to embedded documentation"""
        try:
            import comfyui_embedded_docs

            return str(
                importlib.resources.files(comfyui_embedded_docs) / "docs"
            )
        except ImportError:
            logging.info("comfyui-embedded-docs package not found")
            return None

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
@@ -217,7 +234,7 @@ comfyui-workflow-templates is not installed.
        Raises:
            argparse.ArgumentTypeError: If the version string is invalid.
        """
        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
        match_result = re.match(VERSION_PATTERN, value)
        if match_result is None:
            raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
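The widened `VERSION_PATTERN` accepts the new `prerelease` keyword and semver tags with pre-release suffixes. A quick sketch of the kind of values it now matches (the owner/repo values are illustrative):

```python
# Sketch: values matched by the updated VERSION_PATTERN (owner/repo values are illustrative).
import re

VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"

for value in (
    "Comfy-Org/ComfyUI_frontend@latest",
    "Comfy-Org/ComfyUI_frontend@prerelease",     # newly accepted keyword
    "Comfy-Org/ComfyUI_frontend@v1.2.3-beta.1",  # pre-release suffix now allowed
):
    print(value, bool(re.match(VERSION_PATTERN, value)))  # all three print True
```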
@@ -151,6 +151,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win

parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")

parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@@ -203,6 +204,11 @@ parser.add_argument(
    help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)

database_default_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")

if comfy.options.args_parsing:
    args = parser.parse_args()
else:
@@ -37,6 +37,8 @@ class IO(StrEnum):
    CONTROL_NET = "CONTROL_NET"
    VAE = "VAE"
    MODEL = "MODEL"
    LORA_MODEL = "LORA_MODEL"
    LOSS_MAP = "LOSS_MAP"
    CLIP_VISION = "CLIP_VISION"
    CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
    STYLE_MODEL = "STYLE_MODEL"
@@ -86,3 +86,45 @@ class CONDConstant(CONDRegular):

    def size(self):
        return [1]


class CONDList(CONDRegular):
    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        out = []
        for c in self.cond:
            out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))

        return self._copy_with(out)

    def can_concat(self, other):
        if len(self.cond) != len(other.cond):
            return False
        for i in range(len(self.cond)):
            if self.cond[i].shape != other.cond[i].shape:
                return False

        return True

    def concat(self, others):
        out = []
        for i in range(len(self.cond)):
            o = [self.cond[i]]
            for x in others:
                o.append(x.cond[i])
            out.append(torch.cat(o))

        return out

    def size(self): # hackish implementation to make the mem estimation work
        o = 0
        c = 1
        for c in self.cond:
            size = c.size()
            o += math.prod(size)
            if len(size) > 1:
                c = size[1]

        return [1, c, o // c]
@@ -390,8 +390,9 @@ class ControlLora(ControlNet):
                pass

        for k in self.control_weights:
            if k not in {"lora_controlnet"}:
                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
            if (k not in {"lora_controlnet"}):
                if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
                    comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))

    def copy(self):
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@@ -1,4 +1,5 @@
import math
from functools import partial

from scipy import integrate
import torch
@@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler:
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()


def sigma_to_half_log_snr(sigma, model_sampling):
    """Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        # log((1 - t) / t) = log((1 - sigma) / sigma)
        return sigma.logit().neg()
    return sigma.log().neg()


def half_log_snr_to_sigma(half_log_snr, model_sampling):
    """Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        # 1 / (1 + exp(half_log_snr))
        return half_log_snr.neg().sigmoid()
    return half_log_snr.neg().exp()


def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
    """Adjust the first sigma to avoid invalid logSNR."""
    if len(sigmas) <= 1:
        return sigmas
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        if sigmas[0] >= 1:
            sigmas = sigmas.clone()
            sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
    return sigmas


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
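For reference (not part of the patch), the two conversions above are inverses of each other: for the CONST (rectified-flow) schedule alpha_t = 1 - sigma_t, so the half-logSNR is -logit(sigma), while otherwise alpha_t = 1 and it is simply -log(sigma). A quick sanity check:

```python
# Sanity-check sketch: sigma_to_half_log_snr and half_log_snr_to_sigma invert each other.
# CONST (rectified flow): alpha_t = 1 - sigma_t, so half-logSNR = log((1 - sigma) / sigma) = -logit(sigma).
# Otherwise: alpha_t = 1, so half-logSNR = -log(sigma).
import torch

sigma = torch.tensor([0.25, 0.5, 0.9])

lam_const = sigma.logit().neg()            # CONST case
assert torch.allclose(lam_const.neg().sigmoid(), sigma)

lam_default = sigma.log().neg()            # default case
assert torch.allclose(lam_default.neg().exp(), sigma)
```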
@@ -682,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@@ -693,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
x = x + d * dt
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
s = t + h * r
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
alpha_s = sigmas[i] * lambda_s.exp()
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
|
||||
s_ = t_fn(sd)
|
||||
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
|
||||
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
|
||||
lambda_s_1_ = sd.log().neg()
|
||||
h_ = lambda_s_1_ - lambda_s
|
||||
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
|
||||
if eta > 0 and s_noise > 0:
|
||||
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
|
||||
t_next_ = t_fn(sd)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
|
||||
lambda_t_ = sd.log().neg()
|
||||
h_ = lambda_t_ - lambda_s
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
|
||||
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
|
||||
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
|
||||
return x
|
||||
|
||||
|
||||
@@ -753,6 +793,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
@@ -768,9 +809,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
h, h_last = None, None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -781,26 +825,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
if eta:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@@ -814,6 +861,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
h, h_1, h_2 = None, None, None
|
||||
|
||||
@@ -825,13 +876,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if h_2 is not None:
|
||||
# DPM-Solver++(3M) SDE
|
||||
r0 = h_1 / h
|
||||
r1 = h_2 / h
|
||||
d1_0 = (denoised - denoised_1) / r0
|
||||
@@ -840,20 +894,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d2 = (d1_0 - d1_1) / (r0 + r1)
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
phi_3 = phi_2 / h_eta - 0.5
|
||||
x = x + phi_2 * d1 - phi_3 * d2
|
||||
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
|
||||
elif h_1 is not None:
|
||||
# DPM-Solver++(2M) SDE
|
||||
r = h_1 / h
|
||||
d = (denoised - denoised_1) / r
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
x = x + phi_2 * d
|
||||
x = x + (alpha_t * phi_2) * d
|
||||
|
||||
if eta:
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
denoised_1, denoised_2 = denoised, denoised_1
|
||||
h_1, h_2 = h, h_1
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -863,6 +919,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -872,6 +929,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -1389,14 +1447,15 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@@ -1404,12 +1463,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
def default_er_sde_noise_scaler(x):
|
||||
return x * ((x ** 0.3).exp() + 10.0)
|
||||
|
||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
@@ -1420,41 +1485,45 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
alpha_s = sigmas[i] / er_lambda_s
|
||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@@ -1462,6 +1531,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
@@ -1469,80 +1543,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r < 1
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r_1 * h
|
||||
lambda_s_2 = lambda_s + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r_1 < r_2 < 1
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
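The VP updates above work in half-log-SNR space, lambda = log(alpha / sigma), and recover alpha as sigma * exp(lambda). Illustrative round-trip sketch, assuming a hypothetical variance-preserving schedule with alpha**2 + sigma**2 == 1 (the real sigma_to_half_log_snr / half_log_snr_to_sigma helpers take model_sampling and may implement a different schedule):

import torch

def half_log_snr(sigma):
    # lambda = log(alpha / sigma) under the assumed alpha**2 + sigma**2 == 1 schedule
    alpha = (1 - sigma ** 2).sqrt()
    return alpha.log() - sigma.log()

def sigma_from_half_log_snr(lam):
    # invert lambda = log(alpha / sigma): sigma**2 = sigmoid(-2 * lambda)
    return torch.sigmoid(-2 * lam).sqrt()

sigma = torch.tensor(0.6)
lam = half_log_snr(sigma)
assert torch.allclose(sigma_from_half_log_snr(lam), sigma)
# alpha recovered exactly as in the comment above: alpha = sigma * exp(lambda)
assert torch.allclose(sigma * lam.exp(), (1 - sigma ** 2).sqrt())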
@@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
@@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt blocks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
@@ -178,6 +176,6 @@ class LastLayer(nn.Module):
|
||||
shift, scale = vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
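The modulation rewrites above are behavior-preserving refactors: torch.addcmul(shift, 1 + scale, t) computes shift + (1 + scale) * t in one fused op, and x.addcmul_(gate, out) is the in-place form of x += gate * out. Small equivalence check (illustration only; shapes are arbitrary):

import torch

x, shift, scale, gate, out = (torch.randn(2, 4, 8) for _ in range(5))

assert torch.allclose((1 + scale) * x + shift, torch.addcmul(shift, 1 + scale, x))

y = x.clone()
y.addcmul_(gate, out)                     # in-place: y += gate * out
assert torch.allclose(y, x + gate * out)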
@@ -26,16 +26,6 @@ from torch import nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
|
@@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
h_extrapolation_ratio: float = 1.0,
|
||||
w_extrapolation_ratio: float = 1.0,
|
||||
t_extrapolation_ratio: float = 1.0,
|
||||
enable_fps_modulation: bool = True,
|
||||
device=None,
|
||||
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||
):
|
||||
del kwargs
|
||||
super().__init__()
|
||||
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
|
||||
self.base_fps = base_fps
|
||||
self.max_h = len_h
|
||||
self.max_w = len_w
|
||||
self.enable_fps_modulation = enable_fps_modulation
|
||||
|
||||
dim = head_dim
|
||||
dim_h = dim // 6 * 2
|
||||
@@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
|
||||
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
|
||||
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
|
||||
assert (
|
||||
uniform_fps or B == 1 or T == 1
|
||||
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
|
||||
assert (
|
||||
H <= self.max_h and W <= self.max_w
|
||||
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
|
||||
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
|
||||
|
||||
# apply sequence scaling in temporal dimension
|
||||
if fps is None: # image case
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||
if fps is None or self.enable_fps_modulation is False: # image case
|
||||
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
|
||||
else:
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
|
||||
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
|
||||
|
comfy/ldm/cosmos/predict2.py (new file, 864 lines)
@@ -0,0 +1,864 @@
|
||||
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple
|
||||
import math
|
||||
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
class GPT2FeedForward(nn.Module):
|
||||
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.activation = nn.GELU()
|
||||
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
|
||||
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self._layer_id = None
|
||||
self._dim = d_model
|
||||
self._hidden_dim = d_ff
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.layer1(x)
|
||||
|
||||
x = self.activation(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
|
||||
attention, and rearranges the output back to the original format.
|
||||
|
||||
The input tensor names use the following dimension conventions:
|
||||
|
||||
- B: batch size
|
||||
- S: sequence length
|
||||
- H: number of attention heads
|
||||
- D: head dimension
|
||||
|
||||
Args:
|
||||
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
|
||||
Returns:
|
||||
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
|
||||
"""
|
||||
in_q_shape = q_B_S_H_D.shape
|
||||
in_k_shape = k_B_S_H_D.shape
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
|
||||
|
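For reference only (not part of the new file): for unmasked (B, S, H, D) inputs the op above is roughly PyTorch's scaled_dot_product_attention with the head dimension moved; optimized_attention dispatches to whichever attention backend ComfyUI has selected. Hedged, self-contained approximation:

import torch
import torch.nn.functional as F

def sdpa_reference(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D):
    # (B, S, H, D) -> (B, H, S, D) for SDPA, then back to (B, S, H * D)
    q, k, v = (t.transpose(1, 2) for t in (q_B_S_H_D, k_B_S_H_D, v_B_S_H_D))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(out.shape[0], out.shape[2], -1)

q = k = v = torch.randn(2, 16, 8, 64)
print(sdpa_reference(q, k, v).shape)  # torch.Size([2, 16, 512])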
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
A flexible attention module supporting both self-attention and cross-attention mechanisms.
|
||||
|
||||
This module implements a multi-head attention layer that can operate in either self-attention
|
||||
or cross-attention mode. The mode is determined by whether a context dimension is provided.
|
||||
The implementation uses scaled dot-product attention and supports optional bias terms and
|
||||
dropout regularization.
|
||||
|
||||
Args:
|
||||
query_dim (int): The dimensionality of the query vectors.
|
||||
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
|
||||
If None, the module operates in self-attention mode using query_dim. Default: None
|
||||
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
|
||||
head_dim (int, optional): The dimension of each attention head. Default: 64
|
||||
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
|
||||
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
|
||||
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
|
||||
|
||||
Examples:
|
||||
>>> # Self-attention with 512 dimensions and 8 heads
|
||||
>>> self_attn = Attention(query_dim=512)
|
||||
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
|
||||
>>> out = self_attn(x) # (32, 16, 512)
|
||||
|
||||
>>> # Cross-attention
|
||||
>>> cross_attn = Attention(query_dim=512, context_dim=256)
|
||||
>>> query = torch.randn(32, 16, 512)
|
||||
>>> context = torch.randn(32, 8, 256)
|
||||
>>> out = cross_attn(query, context) # (32, 16, 512)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim: Optional[int] = None,
|
||||
n_heads: int = 8,
|
||||
head_dim: int = 64,
|
||||
dropout: float = 0.0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{n_heads} heads with a dimension of {head_dim}."
|
||||
)
|
||||
self.is_selfattn = context_dim is None # self attention
|
||||
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
inner_dim = head_dim * n_heads
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_norm = nn.Identity()
|
||||
|
||||
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
|
||||
|
||||
self.attn_op = torch_attention_op
|
||||
|
||||
self._query_dim = query_dim
|
||||
self._context_dim = context_dim
|
||||
self._inner_dim = inner_dim
|
||||
|
||||
def compute_qkv(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_proj(x)
|
||||
context = x if context is None else context
|
||||
k = self.k_proj(context)
|
||||
v = self.v_proj(context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
def apply_norm_and_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
v = self.v_norm(v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
|
||||
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
|
||||
timesteps = timesteps_B_T.flatten().float()
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - 0.0)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
|
||||
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
|
||||
|
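Illustration only: for a (B, T) timestep tensor the module above returns (B, T, num_channels) embeddings laid out as [cos | sin]. Standalone sketch of the same computation:

import math
import torch

def sin_cos_embed(timesteps_B_T, num_channels):
    half = num_channels // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
    ang = timesteps_B_T.reshape(-1, 1).float() * freqs[None, :]
    emb = torch.cat([ang.cos(), ang.sin()], dim=-1)            # cos first, matching the module
    return emb.reshape(*timesteps_B_T.shape, num_channels)

print(sin_cos_embed(torch.tensor([[0.0, 10.0]]), 64).shape)    # torch.Size([1, 2, 64])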
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.in_dim = in_features
|
||||
self.out_dim = out_features
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
|
||||
self.activation = nn.SiLU()
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
|
||||
else:
|
||||
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
emb = self.linear_1(sample)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
|
||||
if self.use_adaln_lora:
|
||||
adaln_lora_B_T_3D = emb
|
||||
emb_B_T_D = sample
|
||||
else:
|
||||
adaln_lora_B_T_3D = None
|
||||
emb_B_T_D = emb
|
||||
|
||||
return emb_B_T_D, adaln_lora_B_T_3D
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||
depending on the input format. This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||
and embedding each patch into a vector of size `out_channels`.
|
||||
|
||||
Parameters:
|
||||
- spatial_patch_size (int): The size of each spatial patch.
|
||||
- temporal_patch_size (int): The size of each temporal patch.
|
||||
- in_channels (int): Number of input channels. Default: 3.
|
||||
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 768,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
Rearrange(
|
||||
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||
r=temporal_patch_size,
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
operations.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
|
||||
),
|
||||
)
|
||||
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the PatchEmbed module.
|
||||
|
||||
Parameters:
|
||||
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||
B is the batch size,
|
||||
C is the number of channels,
|
||||
T is the temporal dimension,
|
||||
H is the height, and
|
||||
W is the width of the input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||
"""
|
||||
assert x.dim() == 5
|
||||
_, _, T, H, W = x.shape
|
||||
assert (
|
||||
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
|
||||
assert T % self.temporal_patch_size == 0
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
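Shape sketch (illustration only; the numbers below are made up): the Rearrange above flattens each temporal_patch_size x spatial_patch_size x spatial_patch_size patch into one token of C * r * m * n features, which the Linear then maps to out_channels:

import torch
from einops import rearrange

B, C, T, H, W = 1, 16, 8, 64, 64
ps, pt = 2, 1                           # spatial / temporal patch size
x = torch.randn(B, C, T, H, W)
tokens = rearrange(x, "b c (t r) (h m) (w n) -> b t h w (c r m n)", r=pt, m=ps, n=ps)
print(tokens.shape)                     # torch.Size([1, 8, 32, 32, 64]); Linear maps 64 -> out_channels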
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of video DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
out_channels: int,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.n_adaln_chunks = 2
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
if use_adaln_lora:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (
|
||||
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
|
||||
).chunk(2, dim=-1)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
|
||||
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
|
||||
scale_B_T_D, "b t d -> b t 1 1 d"
|
||||
)
|
||||
|
||||
def _fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
_norm_layer: nn.Module,
|
||||
_scale_B_T_1_1_D: torch.Tensor,
|
||||
_shift_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
|
||||
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
|
||||
return x_B_T_H_W_O
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""
|
||||
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
|
||||
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (int): Dimension of context features for cross-attention
|
||||
num_heads (int): Number of attention heads
|
||||
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
|
||||
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
|
||||
|
||||
The block applies the following sequence:
|
||||
1. Self-attention with AdaLN modulation
|
||||
2. Cross-attention with AdaLN modulation
|
||||
3. MLP with AdaLN modulation
|
||||
|
||||
Each component uses skip connections and layer normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.x_dim = x_dim
|
||||
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if self.use_adaln_lora:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
if self.use_adaln_lora:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
|
||||
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
|
||||
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
|
||||
|
||||
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
|
||||
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
B, T, H, W, D = x_B_T_H_W_D.shape
|
||||
|
||||
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_self_attn,
|
||||
scale_self_attn_B_T_1_1_D,
|
||||
shift_self_attn_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
return _result_B_T_H_W_D
|
||||
|
||||
result_B_T_H_W_D = _x_fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_mlp,
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
class MiniTrainDIT(nn.Module):
|
||||
"""
|
||||
A clean implementation of DiT that can load and reproduce the training results of the original DiT model in Cosmos 1.
|
||||
A general implementation of an AdaLN-modulated ViT-like (DiT) transformer for video processing.
|
||||
|
||||
Args:
|
||||
max_img_h (int): Maximum height of the input images.
|
||||
max_img_w (int): Maximum width of the input images.
|
||||
max_frames (int): Maximum number of frames in the video sequence.
|
||||
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||
out_channels (int): Number of output channels.
|
||||
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||
model_channels (int): Base number of channels used throughout the model.
|
||||
num_blocks (int): Number of transformer blocks.
|
||||
num_heads (int): Number of heads in the multi-head attention layers.
|
||||
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||
pos_emb_cls (str): Type of positional embeddings.
|
||||
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||
min_fps (int): Minimum frames per second.
|
||||
max_fps (int): Maximum frames per second.
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
max_img_w: int,
|
||||
max_frames: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_spatial: int, # tuple,
|
||||
patch_temporal: int,
|
||||
concat_padding_mask: bool = True,
|
||||
# attention settings
|
||||
model_channels: int = 768,
|
||||
num_blocks: int = 10,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
# cross attention settings
|
||||
crossattn_emb_channels: int = 1024,
|
||||
# positional embedding settings
|
||||
pos_emb_cls: str = "sincos",
|
||||
pos_emb_learnable: bool = False,
|
||||
pos_emb_interpolation: str = "crop",
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 30,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
rope_h_extrapolation_ratio: float = 1.0,
|
||||
rope_w_extrapolation_ratio: float = 1.0,
|
||||
rope_t_extrapolation_ratio: float = 1.0,
|
||||
extra_per_block_abs_pos_emb: bool = False,
|
||||
extra_h_extrapolation_ratio: float = 1.0,
|
||||
extra_w_extrapolation_ratio: float = 1.0,
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
rope_enable_fps_modulation: bool = True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_img_h = max_img_h
|
||||
self.max_img_w = max_img_w
|
||||
self.max_frames = max_frames
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_spatial = patch_spatial
|
||||
self.patch_temporal = patch_temporal
|
||||
self.num_heads = num_heads
|
||||
self.num_blocks = num_blocks
|
||||
self.model_channels = model_channels
|
||||
self.concat_padding_mask = concat_padding_mask
|
||||
# positional embedding settings
|
||||
self.pos_emb_cls = pos_emb_cls
|
||||
self.pos_emb_learnable = pos_emb_learnable
|
||||
self.pos_emb_interpolation = pos_emb_interpolation
|
||||
self.min_fps = min_fps
|
||||
self.max_fps = max_fps
|
||||
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||
self.rope_enable_fps_modulation = rope_enable_fps_modulation
|
||||
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
self.t_embedder = nn.Sequential(
|
||||
Timesteps(model_channels),
|
||||
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
|
||||
)
|
||||
|
||||
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.x_embedder = PatchEmbed(
|
||||
spatial_patch_size=patch_spatial,
|
||||
temporal_patch_size=patch_temporal,
|
||||
in_channels=in_channels,
|
||||
out_channels=model_channels,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
x_dim=model_channels,
|
||||
context_dim=crossattn_emb_channels,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=self.model_channels,
|
||||
spatial_patch_size=self.patch_spatial,
|
||||
temporal_patch_size=self.patch_temporal,
|
||||
out_channels=self.out_channels,
|
||||
use_adaln_lora=self.use_adaln_lora,
|
||||
adaln_lora_dim=self.adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def build_pos_embed(self, device=None, dtype=None) -> None:
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
len_w=self.max_img_w // self.patch_spatial,
|
||||
len_t=self.max_frames // self.patch_temporal,
|
||||
max_fps=self.max_fps,
|
||||
min_fps=self.min_fps,
|
||||
is_learnable=self.pos_emb_learnable,
|
||||
interpolation=self.pos_emb_interpolation,
|
||||
head_dim=self.model_channels // self.num_heads,
|
||||
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||
enable_fps_modulation=self.rope_enable_fps_modulation,
|
||||
device=device,
|
||||
)
|
||||
self.pos_embedder = cls_type(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
def prepare_embedded_sequence(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||
|
||||
Args:
|
||||
x_B_C_T_H_W (torch.Tensor): video
|
||||
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||
If None, a default value (`self.base_fps`) will be used.
|
||||
padding_mask (Optional[torch.Tensor]): currently not used
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||
|
||||
Notes:
|
||||
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||
the `self.pos_embedder` with the shape [T, H, W].
|
||||
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||
`self.pos_embedder` with the fps tensor.
|
||||
- Otherwise, the positional embeddings are generated without considering fps.
|
||||
"""
|
||||
if self.concat_padding_mask:
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||
else:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||
else:
|
||||
extra_pos_emb = None
|
||||
|
||||
if "rope" in self.pos_emb_cls.lower():
|
||||
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
|
||||
return x_B_T_H_W_D, None, extra_pos_emb
|
||||
|
||||
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
|
||||
x_B_C_Tt_Hp_Wp = rearrange(
|
||||
x_B_T_H_W_M,
|
||||
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||
p1=self.patch_spatial,
|
||||
p2=self.patch_spatial,
|
||||
t=self.patch_temporal,
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
"""
|
||||
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||
x_B_C_T_H_W,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
if timesteps_B_T.ndim == 1:
|
||||
timesteps_B_T = timesteps_B_T.unsqueeze(1)
|
||||
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
|
||||
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
|
||||
|
||||
# for logging purposes
|
||||
affline_scale_log_info = {}
|
||||
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
|
||||
self.affline_scale_log_info = affline_scale_log_info
|
||||
self.affline_emb = t_embedding_B_T_D
|
||||
self.crossattn_emb = crossattn_emb
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
|
||||
|
||||
block_kwargs = {
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
t_embedding_B_T_D,
|
||||
crossattn_emb,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
@@ -121,6 +121,11 @@ class ControlNetFlux(Flux):
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
else:
|
||||
y = y[:, :self.params.vec_in_dim]
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
@@ -174,7 +179,7 @@ class ControlNetFlux(Flux):
|
||||
out["output"] = out_output[:self.main_model_single]
|
||||
return out
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
|
@@ -118,7 +118,7 @@ class Modulation(nn.Module):
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
return torch.addcmul(m_add, tensor, m_mult)
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
|
@@ -101,6 +101,10 @@ class Flux(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -155,6 +159,9 @@ class Flux(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
@@ -188,20 +195,50 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
for ref in ref_latents:
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
|
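Illustration only: the token-grid math above rounds each spatial dimension up to the patch grid, and out[:, :img_tokens] keeps exactly h_len * w_len main-image tokens so any reference-latent tokens appended by process_img are dropped before unpatchifying. Worked example with hypothetical sizes:

patch_size = 2                                     # as in the forward above
h_orig, w_orig = 123, 64
h_len = (h_orig + patch_size // 2) // patch_size   # 62: 123 is padded up to 124
w_len = (w_orig + patch_size // 2) // patch_size   # 32
img_tokens = h_len * w_len
print(h_len, w_len, img_tokens)                    # 62 32 1984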
@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = True):
|
||||
def __init__(self, sample: bool = False):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
|
||||
@@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
yield from ()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
log = dict()
|
||||
posterior = DiagonalGaussianDistribution(z)
|
||||
if self.sample:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
kl_loss = posterior.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
log["kl_loss"] = kl_loss
|
||||
return z, log
|
||||
return z, None
|
||||
|
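Hedged sketch (assumed internals, not the actual class): DiagonalGaussianDistribution conventionally treats z as concatenated [mean | logvar] along dim 1; mode() returns the mean deterministically while sample() adds scaled noise, which is why the regularizer now defaults to the cheaper deterministic path and no longer builds the KL log dict:

import torch

z = torch.randn(1, 8, 4, 4)                        # pretend latent parameters
mean, logvar = torch.chunk(z, 2, dim=1)            # assumed [mean | logvar] layout
mode = mean                                        # sample=False path
sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
print(mode.shape, sample.shape)                    # torch.Size([1, 4, 4, 4]) twice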
||||
|
||||
class AbstractAutoencoder(torch.nn.Module):
|
||||
|
@@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
for p in patch:
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = n + x
|
||||
if "middle_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_patch"]
|
||||
for p in patch:
|
||||
@@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
for p in patch:
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = n + x
|
||||
if self.is_res:
|
||||
x_skip = x
|
||||
x = self.ff(self.norm3(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
x = x_skip + x
|
||||
|
||||
return x
|
||||
|
||||
|
@@ -31,7 +31,7 @@ def dynamic_slice(
|
||||
starts: List[int],
|
||||
sizes: List[int],
|
||||
) -> Tensor:
|
||||
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
|
||||
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
|
||||
return x[slicing]
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
|
comfy/ldm/omnigen/omnigen2.py (new file, 469 lines)
@@ -0,0 +1,469 @@
|
||||
# Original code: https://github.com/VectorSpaceLab/OmniGen2
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x) * y
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class LuminaRMSNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None])
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class LuminaLayerNormContinuous(nn.Module):
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
x = self.norm(x) * (1 + emb)[:, None, :]
|
||||
if self.linear_2 is not None:
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h1, h2 = self.linear_1(x), self.linear_3(x)
|
||||
return self.linear_2(swiglu(h1, h2))
|
||||
|
||||
|
||||
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
|
||||
self.caption_embedder = nn.Sequential(
|
||||
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
|
||||
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
||||
time_embed = self.timestep_embedder(timestep_proj)
|
||||
caption_embed = self.caption_embedder(text_hidden_states)
|
||||
return time_embed, caption_embed
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.kv_heads = kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
|
||||
nn.Dropout(0.0)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
query = query.view(batch_size, -1, self.heads, self.dim_head)
|
||||
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
|
||||
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
|
||||
|
||||
query = self.norm_q(query)
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if self.kv_heads < self.heads:
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2TransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=dim // num_attention_heads,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.feed_forward = LuminaFeedForward(
|
||||
dim=dim,
|
||||
inner_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if modulation:
|
||||
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.modulation:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.axes_lens = axes_lens
|
||||
self.patch_size = patch_size
|
||||
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
|
||||
|
||||
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
|
||||
p = self.patch_size
|
||||
|
||||
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
|
||||
|
||||
max_seq_len = max(seq_lengths)
|
||||
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
|
||||
|
||||
pe_shift = cap_seq_len
|
||||
pe_shift_len = cap_seq_len
|
||||
|
||||
if ref_img_sizes[i] is not None:
|
||||
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
||||
H, W = ref_img_size
|
||||
ref_H_tokens, ref_W_tokens = H // p, W // p
|
||||
|
||||
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
||||
|
||||
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
||||
pe_shift_len += ref_img_len
|
||||
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // p, W // p
|
||||
|
||||
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
|
||||
|
||||
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
|
||||
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
cap_freqs_cis_shape[1] = encoder_seq_len
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
ref_img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
ref_img_freqs_cis_shape[1] = max_ref_img_len
|
||||
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
||||
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
|
||||
|
||||
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
|
||||
|
||||
|
||||
class OmniGen2Transformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 2304,
|
||||
num_layers: int = 26,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
num_kv_heads: int = 8,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
text_feat_dim: int = 1024,
|
||||
timestep_scale: float = 1.0,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
|
||||
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
||||
theta=10000,
|
||||
axes_dim=axes_dim_rope,
|
||||
axes_lens=axes_lens,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size,
|
||||
text_feat_dim=text_feat_dim,
|
||||
norm_eps=norm_eps,
|
||||
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.noise_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.ref_image_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.context_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
|
||||
|
||||
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
||||
batch_size = len(hidden_states)
|
||||
p = self.patch_size
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
||||
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
|
||||
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
|
||||
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
||||
else:
|
||||
ref_img_sizes = [None for _ in range(batch_size)]
|
||||
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
||||
|
||||
flat_ref_img_hidden_states = None
|
||||
if ref_image_hidden_states is not None:
|
||||
imgs = []
|
||||
for ref_img in ref_image_hidden_states:
|
||||
B, C, H, W = ref_img.size()
|
||||
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
imgs.append(ref_img)
|
||||
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
|
||||
|
||||
img = hidden_states
|
||||
B, C, H, W = img.size()
|
||||
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
|
||||
return (
|
||||
flat_hidden_states, flat_ref_img_hidden_states,
|
||||
None, None,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
)
|
||||
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
|
||||
batch_size = len(hidden_states)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
||||
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
for layer in self.ref_image_refiner:
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
|
||||
|
||||
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
_, _, H_padded, W_padded = hidden_states.shape
|
||||
timestep = 1.0 - timesteps
|
||||
text_hidden_states = context
|
||||
text_attention_mask = attention_mask
|
||||
ref_image_hidden_states = ref_latents
|
||||
device = hidden_states.device
|
||||
|
||||
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
||||
|
||||
(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
||||
|
||||
(
|
||||
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
|
||||
rotary_emb, encoder_seq_lengths, seq_lengths,
|
||||
) = self.rope_embedder(
|
||||
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes, device,
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
|
||||
img_len = hidden_states.shape[1]
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
noise_rotary_emb, ref_img_rotary_emb,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
temb,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
|
||||
attention_mask = None
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
p = self.patch_size
|
||||
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
|
||||
|
||||
return -output
|
@@ -34,12 +34,14 @@ import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -48,6 +50,7 @@ import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -63,38 +66,39 @@ class ModelType(Enum):
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
FLOW_COSMOS = 10
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
s = ModelSamplingDiscrete
|
||||
s = comfy.model_sampling.ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
c = comfy.model_sampling.EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
c = comfy.model_sampling.EPS
|
||||
s = comfy.model_sampling.StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.EDM
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousV
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
elif model_type == ModelType.IMG_TO_IMG:
|
||||
c = comfy.model_sampling.IMG_TO_IMG
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -102,6 +106,13 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
return extra
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
@@ -165,9 +176,14 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = convert_tensor(extra, dtype)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
@@ -800,6 +816,7 @@ class PixArt(BaseModel):
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
try:
|
||||
@@ -860,8 +877,23 @@ class Flux(BaseModel):
|
||||
guidance = kwargs.get("guidance", 3.5)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
@@ -986,6 +1018,45 @@ class CosmosVideo(BaseModel):
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
||||
class CosmosPredict2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
|
||||
self.image_to_video = image_to_video
|
||||
if self.image_to_video:
|
||||
self.concat_keys = ("mask_inverted",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
|
||||
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
if denoise_mask is None:
|
||||
return timestep
|
||||
if denoise_mask.ndim <= 4:
|
||||
return timestep
|
||||
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
|
||||
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
|
||||
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
|
||||
sigma_noise_augmentation = 0 #TODO
|
||||
if sigma_noise_augmentation != 0:
|
||||
latent_image = latent_image + noise
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
sigma = (sigma / (sigma + 1))
|
||||
return latent_image / (1.0 - sigma)
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
@@ -1176,3 +1247,33 @@ class ACEStep(BaseModel):
|
||||
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
@@ -407,6 +407,78 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["text_emb_dim"] = 2048
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos_predict2"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["crossattn_emb_channels"] = 1024
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
dit_config["pos_emb_learnable"] = True
|
||||
dit_config["pos_emb_interpolation"] = "crop"
|
||||
dit_config["min_fps"] = 1
|
||||
dit_config["max_fps"] = 30
|
||||
|
||||
dit_config["use_adaln_lora"] = True
|
||||
dit_config["adaln_lora_dim"] = 256
|
||||
if dit_config["model_channels"] == 2048:
|
||||
dit_config["num_blocks"] = 28
|
||||
dit_config["num_heads"] = 16
|
||||
elif dit_config["model_channels"] == 5120:
|
||||
dit_config["num_blocks"] = 36
|
||||
dit_config["num_heads"] = 40
|
||||
|
||||
if dit_config["in_channels"] == 16:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
elif dit_config["in_channels"] == 17: # img to video
|
||||
if dit_config["model_channels"] == 2048:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
elif dit_config["model_channels"] == 5120:
|
||||
dit_config["rope_h_extrapolation_ratio"] = 2.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 2.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
|
||||
|
||||
dit_config["extra_h_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_w_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_t_extrapolation_ratio"] = 1.0
|
||||
dit_config["rope_enable_fps_modulation"] = False
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "omnigen2"
|
||||
dit_config["axes_dim_rope"] = [40, 40, 40]
|
||||
dit_config["axes_lens"] = [1024, 1664, 1664]
|
||||
dit_config["ffn_dim_multiplier"] = None
|
||||
dit_config["hidden_size"] = 2520
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["multiple_of"] = 256
|
||||
dit_config["norm_eps"] = 1e-05
|
||||
dit_config["num_attention_heads"] = 21
|
||||
dit_config["num_kv_heads"] = 7
|
||||
dit_config["num_layers"] = 32
|
||||
dit_config["num_refiner_layers"] = 2
|
||||
dit_config["out_channels"] = None
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["text_feat_dim"] = 2048
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
@@ -295,14 +295,24 @@ except:
|
||||
pass
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
try:
|
||||
if is_amd():
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -323,7 +333,7 @@ except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
except:
|
||||
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
||||
@@ -1042,7 +1052,7 @@ def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
||||
if is_nvidia():
|
||||
return True
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
@@ -1058,7 +1068,7 @@ def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
|
||||
macos_version = mac_version()
|
||||
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
|
||||
if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
|
||||
upcast = True
|
||||
|
||||
if upcast:
|
||||
@@ -1257,7 +1267,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
if args.supports_fp8_compute:
|
||||
if SUPPORT_FP8_OPS:
|
||||
return True
|
||||
|
||||
if not is_nvidia():
|
||||
@@ -1271,15 +1281,22 @@ def supports_fp8_compute(device=None):
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
|
||||
if torch_version_numeric < (2, 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
|
||||
if torch_version_numeric < (2, 4):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
@@ -17,23 +17,26 @@
 """

 from __future__ import annotations
-from typing import Optional, Callable
-import torch

+import collections
 import copy
 import inspect
 import logging
-import uuid
-import collections
 import math
+import uuid
+from typing import Callable, Optional
+
+import torch

-import comfy.utils
 import comfy.float
-import comfy.model_management
-import comfy.lora
 import comfy.hooks
+import comfy.lora
+import comfy.model_management
 import comfy.patcher_extension
-from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
+import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


 def string_to_seed(data):
     crc = 0xFFFFFFFF
@@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
     def calculate_input(self, sigma, noise):
         return noise

+class COSMOS_RFLOW:
+    def calculate_input(self, sigma, noise):
+        sigma = (sigma / (sigma + 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        return noise * (1.0 - sigma)
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = (sigma / (sigma + 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * (1.0 - sigma) - model_output * sigma
+
+    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        noise = noise * sigma
+        noise += latent_image
+        return noise
+
+    def inverse_noise_scaling(self, sigma, latent):
+        return latent

 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None, zsnr=None):
@@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
         if percent >= 1.0:
             return 0.0
         return flux_time_shift(self.shift, 1.0, 1.0 - percent)
+
+
+class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
+    def timestep(self, sigma):
+        return sigma / (sigma + 1)
+
+    def sigma(self, timestep):
+        sigma_max = self.sigma_max
+        if timestep >= (sigma_max / (sigma_max + 1)):
+            return sigma_max
+
+        return timestep / (1 - timestep)
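Illustrative sketch (not part of the diff): the new COSMOS_RFLOW / ModelSamplingCosmosRFlow code maps sigma to a timestep in [0, 1) via t = sigma / (sigma + 1) and inverts it with sigma = t / (1 - t). The helper names below are made up for the example.

import torch

def timestep_from_sigma(sigma: torch.Tensor) -> torch.Tensor:
    # same expression as ModelSamplingCosmosRFlow.timestep
    return sigma / (sigma + 1)

def sigma_from_timestep(t: torch.Tensor) -> torch.Tensor:
    # same expression as ModelSamplingCosmosRFlow.sigma below sigma_max
    return t / (1 - t)

sigmas = torch.tensor([0.002, 1.0, 80.0])
print(sigma_from_timestep(timestep_from_sigma(sigmas)))  # recovers the sigmas up to float error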
@@ -1039,13 +1039,13 @@ class SchedulerHandler(NamedTuple):
     use_ms: bool = True

 SCHEDULER_HANDLERS = {
-    "normal": SchedulerHandler(normal_scheduler),
-    "simple": SchedulerHandler(simple_scheduler),
-    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
     "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
     "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
+    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
+    "simple": SchedulerHandler(simple_scheduler),
     "ddim_uniform": SchedulerHandler(ddim_scheduler),
     "beta": SchedulerHandler(beta_scheduler),
+    "normal": SchedulerHandler(normal_scheduler),
     "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
     "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
 }
comfy/sd.py (35 changed lines)
@@ -44,6 +44,7 @@ import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
 import comfy.text_encoders.hidream
 import comfy.text_encoders.ace
+import comfy.text_encoders.omnigen2

 import comfy.model_patcher
 import comfy.lora
@@ -754,6 +755,7 @@ class CLIPType(Enum):
     HIDREAM = 14
     CHROMA = 15
     ACE = 16
+    OMNIGEN2 = 17


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +775,7 @@ class TEModel(Enum):
     LLAMA3_8 = 7
     T5_XXL_OLD = 8
     GEMMA_2_2B = 9
+    QWEN25_3B = 10

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +796,8 @@ def detect_te_model(sd):
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
         return TEModel.GEMMA_2_2B
+    if 'model.layers.0.self_attn.k_proj.bias' in sd:
+        return TEModel.QWEN25_3B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         return TEModel.LLAMA3_8
     return None
@@ -894,6 +899,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                                                                         clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+        elif te_model == TEModel.QWEN25_3B:
+            clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:
@@ -1081,7 +1089,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)


-def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}):
+    """
+    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
+
+    Args:
+        sd (dict): State dictionary containing model weights and configuration
+        model_options (dict, optional): Additional options for model loading. Supports:
+            - dtype: Override model data type
+            - custom_operations: Custom model operations
+            - fp8_optimizations: Enable FP8 optimizations
+
+    Returns:
+        ModelPatcher: A wrapped model instance that handles device management and weight loading.
+        Returns None if the model configuration cannot be detected.
+
+    The function:
+    1. Detects and handles different model formats (regular, diffusers, mmdit)
+    2. Configures model dtype based on parameters and device capabilities
+    3. Handles weight conversion and device placement
+    4. Manages model optimization settings
+    5. Loads weights and returns a device-managed model instance
+    """
     dtype = model_options.get("dtype", None)

     #Allow loading unets from checkpoint files
@@ -1139,7 +1168,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
-        logging.info("left over keys in unet: {}".format(left_over))
+        logging.info("left over keys in diffusion model: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)


@@ -1147,7 +1176,7 @@ def load_diffusion_model(unet_path, model_options={}):
     sd = comfy.utils.load_torch_file(unet_path)
     model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
-        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
+        logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model
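A usage sketch of the entry point documented above, assuming a ComfyUI environment; the checkpoint path is made up, and dtype / fp8_optimizations are the model_options keys named in the new docstring.

import torch
import comfy.sd

# Hypothetical path; any diffusion-model checkpoint in a supported format would do.
model_patcher = comfy.sd.load_diffusion_model(
    "models/diffusion_models/model.safetensors",
    model_options={"dtype": torch.bfloat16, "fp8_optimizations": False},
)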
@@ -462,7 +462,7 @@ class SDTokenizer:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
-        self.min_length = min_length
+        self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding

@@ -482,7 +482,8 @@
         if end_token is not None:
             self.end_token = end_token
         else:
-            self.end_token = empty[0]
+            if has_end_token:
+                self.end_token = empty[0]

         if pad_token is not None:
             self.pad_token = pad_token
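An illustrative snippet (values made up) of the new per-model override: tokenizer_data can now set min_length through a "{embedding_key}_min_length" key, mirroring the existing "{}_max_length" mechanism.

tokenizer_data = {"qwen25_3b_min_length": 1}
embedding_key, default_min_length = "qwen25_3b", 0
min_length = tokenizer_data.get("{}_min_length".format(embedding_key), default_min_length)
print(min_length)  # -> 1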
@@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -908,6 +909,48 @@ class CosmosI2V(CosmosT2V):
|
||||
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
"in_channels": 16,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"sigma_data": 1.0,
|
||||
"sigma_max": 80.0,
|
||||
"sigma_min": 0.002,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
"in_channels": 17,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Lumina2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
@@ -1139,6 +1182,41 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
class Omnigen2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "omnigen2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Omnigen2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@@ -24,6 +24,24 @@ class Llama2Config:
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 11008
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 2
|
||||
max_position_embeddings: int = 128000
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@@ -40,6 +58,7 @@ class Gemma2_2B_Config:
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@@ -98,9 +117,9 @@ class Attention(nn.Module):
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
|
||||
ops = ops or nn
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
@@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_3BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
|
comfy/text_encoders/omnigen2.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os


class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)


class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
        self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        if llama_template is None:
            llama_text = self.llama_template.format(text)
        else:
            llama_text = llama_template.format(text)
        return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)

class Qwen25_3BModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class Omnigen2Model(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)


def te(dtype_llama=None, llama_scaled_fp8=None):
    class Omnigen2TEModel_(Omnigen2Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["scaled_fp8"] = llama_scaled_fp8
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return Omnigen2TEModel_
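For reference, a minimal sketch of what the tokenizer above feeds to the Qwen2.5 3B text model: the user prompt is wrapped in the chat template before tokenization. The prompt string here is illustrative, not part of the change.

# Hedged sketch: how Omnigen2Tokenizer wraps a prompt before tokenizing.
# The template mirrors self.llama_template above; the prompt is a made-up example.
template = ('<|im_start|>system\nYou are a helpful assistant that generates '
            'high-quality images based on user instructions.<|im_end|>\n'
            '<|im_start|>user\n{}<|im_end|>\n')
llama_text = template.format("a watercolor painting of a lighthouse at dusk")
# llama_text is what tokenize_with_weights() passes on to the underlying Qwen2Tokenizer.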
comfy/text_encoders/qwen25_tokenizer/merges.txt (new file, +151388 lines)
File diff suppressed because it is too large

comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json (new file, +241 lines)
@@ -0,0 +1,241 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"151643": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151644": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151645": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151646": {
|
||||
"content": "<|object_ref_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151647": {
|
||||
"content": "<|object_ref_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151648": {
|
||||
"content": "<|box_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151649": {
|
||||
"content": "<|box_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151650": {
|
||||
"content": "<|quad_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151651": {
|
||||
"content": "<|quad_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151652": {
|
||||
"content": "<|vision_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151653": {
|
||||
"content": "<|vision_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151654": {
|
||||
"content": "<|vision_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151655": {
|
||||
"content": "<|image_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151656": {
|
||||
"content": "<|video_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151657": {
|
||||
"content": "<tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151658": {
|
||||
"content": "</tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151659": {
|
||||
"content": "<|fim_prefix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151660": {
|
||||
"content": "<|fim_middle|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151661": {
|
||||
"content": "<|fim_suffix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151662": {
|
||||
"content": "<|fim_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151663": {
|
||||
"content": "<|repo_name|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151664": {
|
||||
"content": "<|file_sep|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<|img|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151666": {
|
||||
"content": "<|endofimg|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151667": {
|
||||
"content": "<|meta|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151668": {
|
||||
"content": "<|endofmeta|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"bos_token": null,
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"errors": "replace",
|
||||
"extra_special_tokens": {},
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"processor_class": "Qwen2_5_VLProcessor",
|
||||
"split_special_tokens": false,
|
||||
"tokenizer_class": "Qwen2Tokenizer",
|
||||
"unk_token": null
|
||||
}
|
comfy/text_encoders/qwen25_tokenizer/vocab.json (new file, +1 line)
File diff suppressed because one or more lines are too long
@@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
        )
        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
        return values.contiguous()

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        q = self.q(x)
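A small, self-contained illustration of why the change above returns a contiguous tensor: permute() only rearranges strides and yields a non-contiguous view, while .contiguous() materializes the data so downstream attention code does not have to. The shapes are arbitrary placeholders.

import torch

values = torch.randn(7, 7, 8)                    # (query_length, key_length, num_heads)
bias = values.permute([2, 0, 1]).unsqueeze(0)    # (1, num_heads, query_length, key_length)
print(bias.is_contiguous())                      # False: permute only changes strides
print(bias.contiguous().is_contiguous())         # True: data copied into a dense layout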
@@ -997,11 +997,12 @@ def set_progress_bar_global_hook(function):
    PROGRESS_BAR_HOOK = function

class ProgressBar:
    def __init__(self, total):
    def __init__(self, total, node_id=None):
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK
        self.node_id = node_id

    def update_absolute(self, value, total=None, preview=None):
        if total is not None:
@@ -1010,7 +1011,7 @@ class ProgressBar:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total, preview)
            self.hook(self.current, self.total, preview, node_id=self.node_id)

    def update(self, value):
        self.update_absolute(self.current + value)
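A hedged sketch of how a node implementation might use the extended constructor, assuming the usual pattern of receiving its unique_id through a hidden input; the hook then receives node_id so progress can be attributed to a specific node.

import comfy.utils

def sample_with_progress(steps, unique_id=None):
    # Assumed usage, not part of the diff: attach the node id to the bar.
    pbar = comfy.utils.ProgressBar(steps, node_id=unique_id)
    for _ in range(steps):
        # ... perform one step of work ...
        pbar.update(1)  # forwards (current, total, preview, node_id=...) to the global hook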
@@ -1,4 +1,4 @@
from .base import WeightAdapterBase
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
    OFTAdapter,
    BOFTAdapter,
]

__all__ = [
    "WeightAdapterBase",
    "WeightAdapterTrainBase",
    "adapters"
] + [a.__name__ for a in adapters]
@@ -12,12 +12,20 @@ class WeightAdapterBase:
    weights: list[torch.Tensor]

    @classmethod
    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
    def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        raise NotImplementedError

    @classmethod
    def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
        """
        weight: The original weight tensor to be modified.
        *args: Additional arguments for configuration, such as rank, alpha etc.
        """
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
@@ -33,10 +41,22 @@ class WeightAdapterBase:


class WeightAdapterTrainBase(nn.Module):
    # We follow the scheme of PR #7032
    def __init__(self):
        super().__init__()

    # [TODO] Collaborate with LoRA training PR #7032
    def __call__(self, w):
        """
        w: The original weight tensor to be modified.
        """
        raise NotImplementedError

    def passive_memory_usage(self):
        raise NotImplementedError("passive_memory_usage is not implemented")

    def move_to(self, device):
        self.to(device)
        return self.passive_memory_usage()


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor


def tucker_weight_from_conv(up, down, mid):
    up = up.reshape(up.size(0), up.size(1))
    down = down.reshape(down.size(0), down.size(1))
    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)


def tucker_weight(wa, wb, t):
    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
    return torch.einsum("i j ..., i r -> r j ...", temp, wa)
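A quick shape check for tucker_weight_from_conv under assumed shapes for a 3x3 conv LoRA (up: (out_ch, rank), down: (rank, in_ch), mid: (rank, rank, kh, kw)); the einsum reconstructs a full conv weight of shape (out_ch, in_ch, kh, kw).

import torch

up = torch.randn(64, 8)          # (out_channels, rank)
down = torch.randn(8, 32)        # (rank, in_channels)
mid = torch.randn(8, 8, 3, 3)    # (rank, rank, kh, kw)
w = torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
print(w.shape)                    # torch.Size([64, 32, 3, 3])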
@@ -3,7 +3,56 @@ from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    pad_tensor_to_shape,
    tucker_weight_from_conv,
)


class LoraDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        mat1, mat2, alpha, mid, dora_scale, reshape = weights
        out_dim, rank = mat1.shape[0], mat1.shape[1]
        rank, in_dim = mat2.shape[0], mat2.shape[1]
        if mid is not None:
            convdim = mid.ndim - 2
            layer = (
                torch.nn.Conv1d,
                torch.nn.Conv2d,
                torch.nn.Conv3d
            )[convdim]
        else:
            layer = torch.nn.Linear
        self.lora_up = layer(rank, out_dim, bias=False)
        self.lora_down = layer(in_dim, rank, bias=False)
        self.lora_up.weight.data.copy_(mat1)
        self.lora_down.weight.data.copy_(mat2)
        if mid is not None:
            self.lora_mid = layer(mid, rank, bias=False)
            self.lora_mid.weight.data.copy_(mid)
        else:
            self.lora_mid = None
        self.rank = rank
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        if self.lora_mid is None:
            diff = self.lora_up.weight @ self.lora_down.weight
        else:
            diff = tucker_weight_from_conv(
                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
            )
        scale = self.alpha / self.rank
        weight = w + scale * diff.reshape(w.shape)
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
        torch.nn.init.constant_(mat2, 0.0)
        return LoraDiff(
            (mat1, mat2, alpha, None, None, None)
        )

    def to_train(self):
        return LoraDiff(self.weights)

    @classmethod
    def load(
        cls,
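A hedged usage sketch of the new training entry point: create_train builds a LoraDiff whose down matrix is zero-initialized, so applying it to the original weight is initially a no-op; the tensors here are placeholders.

import torch

weight = torch.randn(320, 768)                            # an existing Linear weight
lora = LoRAAdapter.create_train(weight, rank=8, alpha=8.0)  # returns a LoraDiff module
adapted = lora(weight)                                      # w + (alpha / rank) * up @ down
print(torch.allclose(adapted, weight))                      # True: down starts at zero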
@@ -125,22 +125,6 @@ class BFLFluxKontextProGenerateRequest(BaseModel):
        None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
    )

class BFLFluxKontextMaxGenerateRequest(BaseModel):
    prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
    input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
    seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
    guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
    steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
    safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
        2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
    )
    output_format: Optional[BFLOutputFormat] = Field(
        BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
    )
    aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
    prompt_upsampling: Optional[bool] = Field(
        None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
    )

class BFLFluxProUltraGenerateRequest(BaseModel):
    prompt: str = Field(..., description='The text prompt for image generation.')
@@ -327,7 +327,9 @@ class ApiClient:
            ApiServerError: If the API server is unreachable but internet is working
            Exception: For other request failures
        """
        url = urljoin(self.base_url, path)
        # Use urljoin but ensure path is relative to avoid absolute path behavior
        relative_path = path.lstrip('/')
        url = urljoin(self.base_url, relative_path)
        self.check_auth(self.auth_token, self.comfy_api_key)
        # Combine default headers with any provided headers
        request_headers = self.get_headers()
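The comment above is the whole fix; a short illustration (with a made-up base URL) of why the leading slash has to be stripped before urljoin:

from urllib.parse import urljoin

base = "https://api.example.com/v1/"
print(urljoin(base, "/proxy/bfl/flux-kontext-pro/generate"))
# https://api.example.com/proxy/bfl/flux-kontext-pro/generate    (the /v1 prefix is dropped)
print(urljoin(base, "proxy/bfl/flux-kontext-pro/generate"))
# https://api.example.com/v1/proxy/bfl/flux-kontext-pro/generate (prefix preserved)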
@@ -272,7 +272,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
|
||||
class FluxKontextProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext Pro via api based on prompt and resolution.
|
||||
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
@@ -321,7 +321,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"default": 1234,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
@@ -346,26 +346,14 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -378,11 +366,18 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
if input_image is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-kontext-pro/generate",
|
||||
path=self.BFL_PATH,
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxKontextProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
@@ -393,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
guidance=round(guidance, 1),
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
aspect_ratio=aspect_ratio,
|
||||
input_image=(
|
||||
input_image
|
||||
if input_image is None
|
||||
@@ -411,146 +400,15 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
class FluxKontextMaxImageNode(ComfyNodeABC):
|
||||
|
||||
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext Max via api based on prompt and resolution.
|
||||
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation - specify what and how to edit.",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 3.0,
|
||||
"min": 0.1,
|
||||
"max": 99.0,
|
||||
"step": 0.1,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 1,
|
||||
"max": 150,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"input_image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: Optional[torch.Tensor]=None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if input_image is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-kontext-max/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxKontextProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxKontextProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
guidance=round(guidance, 1),
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
input_image=(
|
||||
input_image
|
||||
if input_image is None
|
||||
else convert_image_to_base64(input_image)
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
class FluxProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
@@ -1208,8 +1066,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
|
||||
# "FluxProImageNode": "Flux 1.1 [pro] Image",
|
||||
"FluxKontextProImageNode": "Flux.1 Kontext Pro Image",
|
||||
"FluxKontextMaxImageNode": "Flux.1 Kontext Max Image",
|
||||
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
|
||||
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
|
||||
"FluxProExpandNode": "Flux.1 Expand Image",
|
||||
"FluxProFillNode": "Flux.1 Fill Image",
|
||||
"FluxProCannyNode": "Flux.1 Canny Control Image",
|
||||
|
@@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC):

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Ideogram/v1"
    CATEGORY = "api node/image/Ideogram"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

@@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC):

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Ideogram/v2"
    CATEGORY = "api node/image/Ideogram"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True

@@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC):

    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Ideogram/v3"
    CATEGORY = "api node/image/Ideogram"
    DESCRIPTION = cleandoc(__doc__ or "")
    API_NODE = True
comfy_config/config_parser.py (new file, +152 lines)
@@ -0,0 +1,152 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource
|
||||
|
||||
from comfy_config.types import (
|
||||
ComfyConfig,
|
||||
ProjectConfig,
|
||||
PyProjectConfig,
|
||||
PyProjectSettings
|
||||
)
|
||||
|
||||
def validate_and_extract_os_classifiers(classifiers: list) -> list:
|
||||
os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
|
||||
if not os_classifiers:
|
||||
return []
|
||||
|
||||
os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
|
||||
valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
|
||||
|
||||
for os_value in os_values:
|
||||
if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
|
||||
return []
|
||||
|
||||
return os_values
|
||||
|
||||
|
||||
def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
|
||||
accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
|
||||
if not accelerator_classifiers:
|
||||
return []
|
||||
|
||||
accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
|
||||
|
||||
valid_accelerators = {
|
||||
"GPU :: NVIDIA CUDA",
|
||||
"GPU :: AMD ROCm",
|
||||
"GPU :: Intel Arc",
|
||||
"NPU :: Huawei Ascend",
|
||||
"GPU :: Apple Metal",
|
||||
}
|
||||
|
||||
for accelerator_value in accelerator_values:
|
||||
if accelerator_value not in valid_accelerators:
|
||||
return []
|
||||
|
||||
return accelerator_values
|
||||
|
||||
|
||||
"""
|
||||
Extract configuration from a custom node directory's pyproject.toml file or a Python file.
|
||||
|
||||
This function reads and parses the pyproject.toml file in the specified directory
|
||||
to extract project and ComfyUI-specific configuration information. If no
|
||||
pyproject.toml file is found, it creates a minimal configuration using the
|
||||
folder name as the project name. If a Python file is provided, it uses the
|
||||
file name (without extension) as the project name.
|
||||
|
||||
Args:
|
||||
path (str): Path to the directory containing the pyproject.toml file, or
|
||||
path to a .py file. If pyproject.toml doesn't exist in a directory,
|
||||
the folder name will be used as the default project name. If a .py
|
||||
file is provided, the filename (without .py extension) will be used
|
||||
as the project name.
|
||||
|
||||
Returns:
|
||||
Optional[PyProjectConfig]: A PyProjectConfig object containing:
|
||||
- project: Basic project information (name, version, dependencies, etc.)
|
||||
- tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.)
|
||||
Returns None if configuration extraction fails or if the provided file
|
||||
is not a Python file.
|
||||
|
||||
Notes:
|
||||
- If pyproject.toml is missing in a directory, creates a default config with folder name
|
||||
- If a .py file is provided, creates a default config with filename (without extension)
|
||||
- Returns None for non-Python files
|
||||
|
||||
Example:
|
||||
>>> from comfy_config import config_parser
|
||||
>>> # For directory
|
||||
>>> custom_node_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
>>> project_config = config_parser.extract_node_configuration(custom_node_dir)
|
||||
>>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml
|
||||
>>>
|
||||
>>> # For single-file Python node file
|
||||
>>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py"
|
||||
>>> project_config = config_parser.extract_node_configuration(py_file_path)
|
||||
>>> print(project_config.project.name) # "my_node"
|
||||
"""
|
||||
def extract_node_configuration(path) -> Optional[PyProjectConfig]:
|
||||
if os.path.isfile(path):
|
||||
file_path = Path(path)
|
||||
|
||||
if file_path.suffix.lower() != '.py':
|
||||
return None
|
||||
|
||||
project_name = file_path.stem
|
||||
project = ProjectConfig(name=project_name)
|
||||
comfy = ComfyConfig()
|
||||
return PyProjectConfig(project=project, tool_comfy=comfy)
|
||||
|
||||
folder_name = os.path.basename(path)
|
||||
toml_path = Path(path) / "pyproject.toml"
|
||||
|
||||
if not toml_path.exists():
|
||||
project = ProjectConfig(name=folder_name)
|
||||
comfy = ComfyConfig()
|
||||
return PyProjectConfig(project=project, tool_comfy=comfy)
|
||||
|
||||
raw_settings = load_pyproject_settings(toml_path)
|
||||
|
||||
project_data = raw_settings.project
|
||||
|
||||
tool_data = raw_settings.tool
|
||||
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
|
||||
|
||||
dependencies = project_data.get("dependencies", [])
|
||||
supported_comfyui_frontend_version = ""
|
||||
for dep in dependencies:
|
||||
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
|
||||
supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
|
||||
break
|
||||
|
||||
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
|
||||
|
||||
classifiers = project_data.get('classifiers', [])
|
||||
supported_os = validate_and_extract_os_classifiers(classifiers)
|
||||
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
|
||||
|
||||
project_data['supported_os'] = supported_os
|
||||
project_data['supported_accelerators'] = supported_accelerators
|
||||
project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
|
||||
project_data['supported_comfyui_version'] = supported_comfyui_version
|
||||
|
||||
return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
|
||||
|
||||
|
||||
def load_pyproject_settings(toml_path: Path) -> PyProjectSettings:
|
||||
class PyProjectLoader(PyProjectSettings):
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls,
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
):
|
||||
return (TomlConfigSettingsSource(settings_cls, toml_path),)
|
||||
|
||||
return PyProjectLoader()
|
comfy_config/types.py (new file, +97 lines)
@@ -0,0 +1,97 @@
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional

# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes
# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py.
# Any changes to one must be reflected in the other to maintain consistency.

class NodeVersion(BaseModel):
    changelog: str
    dependencies: List[str]
    deprecated: bool
    id: str
    version: str
    download_url: str


class Node(BaseModel):
    id: str
    name: str
    description: str
    author: Optional[str] = None
    license: Optional[str] = None
    icon: Optional[str] = None
    repository: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    latest_version: Optional[NodeVersion] = None


class PublishNodeVersionResponse(BaseModel):
    node_version: NodeVersion
    signedUrl: str


class URLs(BaseModel):
    homepage: str = Field(default="", alias="Homepage")
    documentation: str = Field(default="", alias="Documentation")
    repository: str = Field(default="", alias="Repository")
    issues: str = Field(default="", alias="Issues")


class Model(BaseModel):
    location: str
    model_url: str


class ComfyConfig(BaseModel):
    publisher_id: str = Field(default="", alias="PublisherId")
    display_name: str = Field(default="", alias="DisplayName")
    icon: str = Field(default="", alias="Icon")
    models: List[Model] = Field(default_factory=list, alias="Models")
    includes: List[str] = Field(default_factory=list)
    web: Optional[str] = None
    banner_url: str = ""

class License(BaseModel):
    file: str = ""
    text: str = ""


class ProjectConfig(BaseModel):
    name: str = ""
    description: str = ""
    version: str = "1.0.0"
    requires_python: str = Field(default=">= 3.9", alias="requires-python")
    dependencies: List[str] = Field(default_factory=list)
    license: License = Field(default_factory=License)
    urls: URLs = Field(default_factory=URLs)
    supported_os: List[str] = Field(default_factory=list)
    supported_accelerators: List[str] = Field(default_factory=list)
    supported_comfyui_version: str = ""
    supported_comfyui_frontend_version: str = ""

    @field_validator('license', mode='before')
    @classmethod
    def validate_license(cls, v):
        if isinstance(v, str):
            return License(text=v)
        elif isinstance(v, dict):
            return License(**v)
        elif isinstance(v, License):
            return v
        else:
            return License()


class PyProjectConfig(BaseModel):
    project: ProjectConfig = Field(default_factory=ProjectConfig)
    tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig)


class PyProjectSettings(BaseSettings):
    project: dict = Field(default_factory=dict)

    tool: dict = Field(default_factory=dict)

    model_config = SettingsConfigDict(extra='allow')
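A hedged sketch of what the validator above accepts, mirroring the two license forms pyproject.toml allows (a plain SPDX-style string or a table); the values are illustrative.

cfg = ProjectConfig(name="my-node", license="MIT")
print(cfg.license.text)    # "MIT"

cfg2 = ProjectConfig(name="my-node", license={"file": "LICENSE"})
print(cfg2.license.file)   # "LICENSE"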
@@ -1,6 +1,7 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
|
||||
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet:
|
||||
class CacheKeySet(ABC):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
@abstractmethod
|
||||
async def add_keys(self, node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
@@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
async def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
@@ -150,9 +150,10 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
await self.cache_key_set.add_keys(node_ids)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.initialized = True
|
||||
|
||||
@@ -201,13 +202,13 @@ class BasicCache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
@@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
@@ -273,8 +274,8 @@ class LRUCache(BasicCache):
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
@@ -303,11 +304,11 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
super()._ensure_subcache(node_id, children_ids)
|
||||
await super()._ensure_subcache(node_id, children_ids)
|
||||
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
await self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
@@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
@@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
@@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
@@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
|
@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Type, Literal

import nodes
import asyncio
from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions

@@ -100,6 +101,8 @@ class TopologicalSort:
        self.pendingNodes = {}
        self.blockCount = {}  # Number of nodes this node is directly blocked by
        self.blocking = {}  # Which nodes are blocked by this node
        self.externalBlocks = 0
        self.unblockedEvent = asyncio.Event()

    def get_input_info(self, unique_id, input_name):
        class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +156,16 @@ class TopologicalSort:
        for link in links:
            self.add_strong_link(*link)

    def add_external_block(self, node_id):
        assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
        self.externalBlocks += 1
        self.blockCount[node_id] += 1
        def unblock():
            self.externalBlocks -= 1
            self.blockCount[node_id] -= 1
            self.unblockedEvent.set()
        return unblock

    def is_cached(self, node_id):
        return False

@@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort):
    def is_cached(self, node_id):
        return self.output_cache.get(node_id) is not None

    def stage_node_execution(self):
    async def stage_node_execution(self):
        assert self.staged_node_id is None
        if self.is_empty():
            return None, None, None
        available = self.get_ready_nodes()
        while len(available) == 0 and self.externalBlocks > 0:
            # Wait for an external block to be released
            await self.unblockedEvent.wait()
            self.unblockedEvent.clear()
            available = self.get_ready_nodes()
        if len(available) == 0:
            cycled_nodes = self.get_nodes_in_cycle()
            # Because cycles composed entirely of static nodes are caught during initial validation,
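A hedged sketch of the intended use of add_external_block: something outside the graph (for example a node whose result arrives asynchronously) can keep a pending node blocked, and stage_node_execution now awaits unblockedEvent instead of failing when nothing else is ready. Names other than the methods shown above are illustrative.

# Illustrative only: block node "42" until some external work completes.
unblock = execution_list.add_external_block("42")

async def finish_external_work():
    await do_external_work()   # placeholder for whatever is being waited on
    unblock()                  # decrements the block count and wakes stage_node_execution()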
comfy_execution/progress.py (new file, +288 lines)
@@ -0,0 +1,288 @@
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
from abc import ABC
|
||||
from tqdm import tqdm
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
Running = "running"
|
||||
Finished = "finished"
|
||||
Error = "error"
|
||||
|
||||
class NodeProgressState(TypedDict):
|
||||
"""
|
||||
A class to represent the state of a node's progress.
|
||||
"""
|
||||
state: NodeState
|
||||
value: float
|
||||
max: float
|
||||
|
||||
class ProgressHandler(ABC):
|
||||
"""
|
||||
Abstract base class for progress handlers.
|
||||
Progress handlers receive progress updates and display them in various ways.
|
||||
"""
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
pass
|
||||
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node starts processing"""
|
||||
pass
|
||||
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node finishes processing"""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
"""Called when the progress registry is reset"""
|
||||
pass
|
||||
|
||||
def enable(self):
|
||||
"""Enable this handler"""
|
||||
self.enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable this handler"""
|
||||
self.enabled = False
|
||||
|
||||
class CLIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that displays progress using tqdm progress bars in the CLI.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__("cli")
|
||||
self.progress_bars: Dict[str, tqdm] = {}
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Create a new tqdm progress bar
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=state["max"],
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
|
||||
@override
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=max_value,
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
self.progress_bars[node_id].update(value)
|
||||
else:
|
||||
# Update existing progress bar
|
||||
if max_value != self.progress_bars[node_id].total:
|
||||
self.progress_bars[node_id].total = max_value
|
||||
# Calculate the update amount (difference from current position)
|
||||
current_position = self.progress_bars[node_id].n
|
||||
update_amount = value - current_position
|
||||
if update_amount > 0:
|
||||
self.progress_bars[node_id].update(update_amount)
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Complete and close the progress bar if it exists
|
||||
if node_id in self.progress_bars:
|
||||
# Ensure the bar shows 100% completion
|
||||
remaining = state["max"] - self.progress_bars[node_id].n
|
||||
if remaining > 0:
|
||||
self.progress_bars[node_id].update(remaining)
|
||||
self.progress_bars[node_id].close()
|
||||
del self.progress_bars[node_id]
|
||||
|
||||
@override
|
||||
def reset(self):
|
||||
# Close all progress bars
|
||||
for bar in self.progress_bars.values():
|
||||
bar.close()
|
||||
self.progress_bars.clear()
|
||||
|
||||
class WebUIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that sends progress updates to the WebUI via WebSockets.
|
||||
"""
|
||||
def __init__(self, server_instance):
|
||||
super().__init__("webui")
|
||||
self.server_instance = server_instance
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
self.registry = registry
|
||||
|
||||
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||
"""Send the current progress state to the client"""
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
"value": state["value"],
|
||||
"max": state["max"],
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
for node_id, state in nodes.items()
|
||||
if state["state"] != NodeState.Pending
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
self.server_instance.send_sync("progress_state", {
|
||||
"prompt_id": prompt_id,
|
||||
"nodes": active_nodes
|
||||
})
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
@override
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
if image:
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
|
||||
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
self.handlers: Dict[str, ProgressHandler] = {}
|
||||
|
||||
def register_handler(self, handler: ProgressHandler) -> None:
|
||||
"""Register a progress handler"""
|
||||
self.handlers[handler.name] = handler
|
||||
|
||||
def unregister_handler(self, handler_name: str) -> None:
|
||||
"""Unregister a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
# Allow handler to clean up resources
|
||||
self.handlers[handler_name].reset()
|
||||
del self.handlers[handler_name]
|
||||
|
||||
def enable_handler(self, handler_name: str) -> None:
|
||||
"""Enable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].enable()
|
||||
|
||||
def disable_handler(self, handler_name: str) -> None:
|
||||
"""Disable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].disable()
|
||||
|
||||
def ensure_entry(self, node_id: str) -> NodeProgressState:
|
||||
"""Ensure a node entry exists"""
|
||||
if node_id not in self.nodes:
|
||||
self.nodes[node_id] = NodeProgressState(
|
||||
state = NodeState.Pending,
|
||||
value = 0,
|
||||
max = 1
|
||||
)
|
||||
return self.nodes[node_id]
|
||||
|
||||
def start_progress(self, node_id: str) -> None:
|
||||
"""Start progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = 0.0
|
||||
entry["max"] = 1.0
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = value
|
||||
entry["max"] = max_value
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
|
||||
|
||||
def finish_progress(self, node_id: str) -> None:
|
||||
"""Finish progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Finished
|
||||
entry["value"] = entry["max"]
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.finish_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def reset_handlers(self) -> None:
|
||||
"""Reset all handlers"""
|
||||
for handler in self.handlers.values():
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
if global_progress_registry is not None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
handler.set_registry(global_progress_registry)
|
||||
global_progress_registry.register_handler(handler)
|
||||
|
||||
def get_progress_state() -> ProgressRegistry:
|
||||
return global_progress_registry
|
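A minimal usage sketch, not part of the diff, assuming only the registry API shown above; the PrintProgressHandler class and the prompt/node ids are hypothetical stand-ins for a real ProgressHandler subclass:

# Sketch only: a duck-typed handler wired into the global registry defined above.
from comfy_execution.graph import DynamicPrompt
from comfy_execution.progress import add_progress_handler, get_progress_state, reset_progress_state

class PrintProgressHandler:
    # Minimal stand-in exposing only the attributes/methods the registry calls.
    name = "print_handler"
    enabled = True

    def set_registry(self, registry):
        self.registry = registry

    def start_handler(self, node_id, state, prompt_id):
        print(f"[{prompt_id}] start {node_id}")

    def update_handler(self, node_id, value, max_value, state, prompt_id, image=None):
        print(f"[{prompt_id}] node {node_id}: {value}/{max_value}")

    def finish_handler(self, node_id, state, prompt_id):
        print(f"[{prompt_id}] done {node_id}")

    def reset(self):
        pass

reset_progress_state("prompt-123", DynamicPrompt({}))  # fresh registry for a new (hypothetical) prompt
add_progress_handler(PrintProgressHandler())
registry = get_progress_state()
registry.start_progress("5")
registry.update_progress("5", 2, 10, None)
registry.finish_progress("5")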
46
comfy_execution/utils.py
Normal file
@@ -0,0 +1,46 @@
import contextvars
from typing import Optional, NamedTuple

class ExecutionContext(NamedTuple):
    """
    Context information about the currently executing node.

    Attributes:
        prompt_id: The ID of the prompt being executed
        node_id: The ID of the currently executing node
        list_index: The index in a list being processed (for operations on batches/lists)
    """
    prompt_id: str
    node_id: str
    list_index: Optional[int]

current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)

def get_executing_context() -> Optional[ExecutionContext]:
    return current_executing_context.get(None)

class CurrentNodeContext:
    """
    Context manager for setting the current executing node context.

    Sets the current_executing_context on enter and resets it on exit.

    Example:
        with CurrentNodeContext(prompt_id="abc-123", node_id="123", list_index=0):
            # Code that should run with the current node context set
            process_image()
    """
    def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
        self.context = ExecutionContext(
            prompt_id=prompt_id,
            node_id=node_id,
            list_index=list_index
        )
        self.token = None

    def __enter__(self):
        self.token = current_executing_context.set(self.context)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.token is not None:
            current_executing_context.reset(self.token)
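A short usage sketch, not part of the diff, of the two helpers above; the prompt and node ids are made up:

# Sketch only: reading the ambient execution context from arbitrary code.
from comfy_execution.utils import CurrentNodeContext, get_executing_context

def log_current_node():
    ctx = get_executing_context()
    if ctx is None:
        print("not running inside a node")
    else:
        print(f"prompt {ctx.prompt_id}, node {ctx.node_id}, list index {ctx.list_index}")

with CurrentNodeContext(prompt_id="prompt-123", node_id="7", list_index=0):
    log_current_node()   # -> prompt prompt-123, node 7, list index 0
log_current_node()       # -> not running inside a node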
@@ -2,6 +2,7 @@ import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.latent_formats
|
||||
|
||||
|
||||
class EmptyCosmosLatentVideo:
|
||||
@@ -75,8 +76,53 @@ class CosmosImageToVideoLatent:
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
|
||||
class CosmosPredict2ImageToVideoLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is None and end_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (out_latent,)
|
||||
|
||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
||||
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
||||
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
||||
|
||||
if end_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
||||
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
||||
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
|
||||
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
|
||||
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ import math
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import latent_preview
|
||||
import torch
|
||||
import comfy.utils
|
||||
@@ -480,6 +481,46 @@ class SamplerDPMAdaptative:
|
||||
"s_noise":s_noise })
|
||||
return (sampler, )
|
||||
|
||||
|
||||
class SamplerER_SDE(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
|
||||
"max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
|
||||
"eta": (
|
||||
IO.FLOAT,
|
||||
{"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
|
||||
),
|
||||
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.SAMPLER,)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, solver_type, max_stage, eta, s_noise):
|
||||
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
|
||||
eta = 0
|
||||
s_noise = 0
|
||||
|
||||
def reverse_time_sde_noise_scaler(x):
|
||||
return x ** (eta + 1)
|
||||
|
||||
if solver_type == "ER-SDE":
|
||||
# Use the default one in sample_er_sde()
|
||||
noise_scaler = None
|
||||
else:
|
||||
noise_scaler = reverse_time_sde_noise_scaler
|
||||
|
||||
sampler_name = "er_sde"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@@ -609,8 +650,14 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
middle_cond = self.conds.get("middle", None)
|
||||
positive_cond = self.conds.get("positive", None)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
|
||||
class DualCFGGuider:
|
||||
@@ -781,6 +828,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SamplerER_SDE": SamplerER_SDE,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
|
26
comfy_extras/nodes_edit_model.py
Normal file
@@ -0,0 +1,26 @@
import node_helpers


class ReferenceLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             },
                "optional": {"latent": ("LATENT", ),}
                }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "advanced/conditioning/edit_models"
    DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."

    def append(self, conditioning, latent=None):
        if latent is not None:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
        return (conditioning, )


NODE_CLASS_MAPPINGS = {
    "ReferenceLatent": ReferenceLatent,
}
@@ -1,4 +1,5 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
|
||||
class CLIPTextEncodeFlux:
|
||||
@classmethod
|
||||
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
|
||||
return (c, )
|
||||
|
||||
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
(752, 1392),
|
||||
(800, 1328),
|
||||
(832, 1248),
|
||||
(880, 1184),
|
||||
(944, 1104),
|
||||
(1024, 1024),
|
||||
(1104, 944),
|
||||
(1184, 880),
|
||||
(1248, 832),
|
||||
(1328, 800),
|
||||
(1392, 752),
|
||||
(1456, 720),
|
||||
(1504, 688),
|
||||
(1568, 672),
|
||||
]
|
||||
|
||||
|
||||
class FluxKontextImageScale:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"image": ("IMAGE", ),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "scale"
|
||||
|
||||
CATEGORY = "advanced/conditioning/flux"
|
||||
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
|
||||
|
||||
def scale(self, image):
|
||||
width = image.shape[2]
|
||||
height = image.shape[1]
|
||||
aspect_ratio = width / height
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||
return (image, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||
"FluxGuidance": FluxGuidance,
|
||||
"FluxDisableGuidance": FluxDisableGuidance,
|
||||
"FluxKontextImageScale": FluxKontextImageScale,
|
||||
}
|
||||
|
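The resolution-matching rule in FluxKontextImageScale above picks the preset whose aspect ratio is closest to the input before Lanczos-rescaling. A standalone sketch, not part of the diff, using an abbreviated preset list:

# Sketch only: pick the preferred Kontext resolution closest in aspect ratio to a 1920x1080 input.
PREFERED_KONTEXT_RESOLUTIONS = [(672, 1568), (1024, 1024), (1568, 672)]  # abbreviated list

width, height = 1920, 1080
aspect_ratio = width / height
_, best_w, best_h = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
print(best_w, best_h)  # -> 1568 672 for a 16:9 input with this abbreviated list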
@@ -14,8 +14,10 @@ import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
from comfy.comfy_types import FileLocator, IO
|
||||
from server import PromptServer
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
@@ -229,6 +231,246 @@ class SVG:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
||||
|
||||
class ImageStitch:
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||
"spacing_width": (
|
||||
"INT",
|
||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||
),
|
||||
"spacing_color": (
|
||||
["white", "black", "red", "green", "blue"],
|
||||
{"default": "white"},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stitch"
|
||||
CATEGORY = "image/transform"
|
||||
DESCRIPTION = """
|
||||
Stitches image2 to image1 in the specified direction.
|
||||
If image2 is not provided, returns image1 unchanged.
|
||||
Optional spacing can be added between images.
|
||||
"""
|
||||
|
||||
def stitch(
|
||||
self,
|
||||
image1,
|
||||
direction,
|
||||
match_image_size,
|
||||
spacing_width,
|
||||
spacing_color,
|
||||
image2=None,
|
||||
):
|
||||
if image2 is None:
|
||||
return (image1,)
|
||||
|
||||
# Handle batch size differences
|
||||
if image1.shape[0] != image2.shape[0]:
|
||||
max_batch = max(image1.shape[0], image2.shape[0])
|
||||
if image1.shape[0] < max_batch:
|
||||
image1 = torch.cat(
|
||||
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
|
||||
)
|
||||
if image2.shape[0] < max_batch:
|
||||
image2 = torch.cat(
|
||||
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
|
||||
)
|
||||
|
||||
# Match image sizes if requested
|
||||
if match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
aspect_ratio = w2 / h2
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
target_h, target_w = h1, int(h1 * aspect_ratio)
|
||||
else: # up, down
|
||||
target_w, target_h = w1, int(w1 / aspect_ratio)
|
||||
|
||||
image2 = comfy.utils.common_upscale(
|
||||
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
|
||||
).movedim(1, -1)
|
||||
|
||||
color_map = {
|
||||
"white": 1.0,
|
||||
"black": 0.0,
|
||||
"red": (1.0, 0.0, 0.0),
|
||||
"green": (0.0, 1.0, 0.0),
|
||||
"blue": (0.0, 0.0, 1.0),
|
||||
}
|
||||
|
||||
color_val = color_map[spacing_color]
|
||||
|
||||
# When not matching sizes, pad to align non-concat dimensions
|
||||
if not match_image_size:
|
||||
h1, w1 = image1.shape[1:3]
|
||||
h2, w2 = image2.shape[1:3]
|
||||
pad_value = 0.0
|
||||
if not isinstance(color_val, tuple):
|
||||
pad_value = color_val
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
# For horizontal concat, pad heights to match
|
||||
if h1 != h2:
|
||||
target_h = max(h1, h2)
|
||||
if h1 < target_h:
|
||||
pad_h = target_h - h1
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
|
||||
if h2 < target_h:
|
||||
pad_h = target_h - h2
|
||||
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
|
||||
else: # up, down
|
||||
# For vertical concat, pad widths to match
|
||||
if w1 != w2:
|
||||
target_w = max(w1, w2)
|
||||
if w1 < target_w:
|
||||
pad_w = target_w - w1
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
|
||||
if w2 < target_w:
|
||||
pad_w = target_w - w2
|
||||
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
||||
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
|
||||
|
||||
# Ensure same number of channels
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
max_channels = max(image1.shape[-1], image2.shape[-1])
|
||||
if image1.shape[-1] < max_channels:
|
||||
image1 = torch.cat(
|
||||
[
|
||||
image1,
|
||||
torch.ones(
|
||||
*image1.shape[:-1],
|
||||
max_channels - image1.shape[-1],
|
||||
device=image1.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if image2.shape[-1] < max_channels:
|
||||
image2 = torch.cat(
|
||||
[
|
||||
image2,
|
||||
torch.ones(
|
||||
*image2.shape[:-1],
|
||||
max_channels - image2.shape[-1],
|
||||
device=image2.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Add spacing if specified
|
||||
if spacing_width > 0:
|
||||
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
|
||||
|
||||
if direction in ["left", "right"]:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
max(image1.shape[1], image2.shape[1]),
|
||||
spacing_width,
|
||||
image1.shape[-1],
|
||||
)
|
||||
else:
|
||||
spacing_shape = (
|
||||
image1.shape[0],
|
||||
spacing_width,
|
||||
max(image1.shape[2], image2.shape[2]),
|
||||
image1.shape[-1],
|
||||
)
|
||||
|
||||
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
|
||||
if isinstance(color_val, tuple):
|
||||
for i, c in enumerate(color_val):
|
||||
if i < spacing.shape[-1]:
|
||||
spacing[..., i] = c
|
||||
if spacing.shape[-1] == 4: # Add alpha
|
||||
spacing[..., 3] = 1.0
|
||||
else:
|
||||
spacing[..., : min(3, spacing.shape[-1])] = color_val
|
||||
if spacing.shape[-1] == 4:
|
||||
spacing[..., 3] = 1.0
|
||||
|
||||
# Concatenate images
|
||||
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
|
||||
if spacing_width > 0:
|
||||
images.insert(1, spacing)
|
||||
|
||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||
return (torch.cat(images, dim=concat_dim),)
|
||||
|
||||
class ResizeAndPadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"target_width": ("INT", {
|
||||
"default": 512,
|
||||
"min": 1,
|
||||
"max": MAX_RESOLUTION,
|
||||
"step": 1
|
||||
}),
|
||||
"target_height": ("INT", {
|
||||
"default": 512,
|
||||
"min": 1,
|
||||
"max": MAX_RESOLUTION,
|
||||
"step": 1
|
||||
}),
|
||||
"padding_color": (["white", "black"],),
|
||||
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "resize_and_pad"
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
|
||||
batch_size, orig_height, orig_width, channels = image.shape
|
||||
|
||||
scale_w = target_width / orig_width
|
||||
scale_h = target_height / orig_height
|
||||
scale = min(scale_w, scale_h)
|
||||
|
||||
new_width = int(orig_width * scale)
|
||||
new_height = int(orig_height * scale)
|
||||
|
||||
image_permuted = image.permute(0, 3, 1, 2)
|
||||
|
||||
resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
|
||||
|
||||
pad_value = 0.0 if padding_color == "black" else 1.0
|
||||
padded = torch.full(
|
||||
(batch_size, channels, target_height, target_width),
|
||||
pad_value,
|
||||
dtype=image.dtype,
|
||||
device=image.device
|
||||
)
|
||||
|
||||
y_offset = (target_height - new_height) // 2
|
||||
x_offset = (target_width - new_width) // 2
|
||||
|
||||
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
|
||||
|
||||
output = padded.permute(0, 2, 3, 1)
|
||||
return (output,)
|
||||
|
||||
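The letterbox math in resize_and_pad above scales by the smaller of the two ratios and centres the resized image on the padded canvas. A quick numeric sketch, not part of the diff:

# Sketch only: compute the resize size and padding offsets for a 1280x720 image into a 512x512 canvas.
orig_width, orig_height = 1280, 720
target_width, target_height = 512, 512

scale = min(target_width / orig_width, target_height / orig_height)       # 0.4
new_width, new_height = int(orig_width * scale), int(orig_height * scale)  # 512 x 288

x_offset = (target_width - new_width) // 2    # 0
y_offset = (target_height - new_height) // 2  # 112
print(new_width, new_height, x_offset, y_offset)  # -> 512 288 0 112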
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
@@ -310,6 +552,37 @@ class SaveSVGNode:
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
class GetImageSize:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
|
||||
RETURN_NAMES = ("width", "height", "batch_size")
|
||||
FUNCTION = "get_size"
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
|
||||
|
||||
def get_size(self, image, unique_id=None) -> tuple[int, int]:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
batch_size = image.shape[0]
|
||||
|
||||
# Send progress text to display size on the node
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
|
||||
|
||||
return width, height, batch_size
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
@@ -318,4 +591,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
"ResizeAndPadImage": ResizeAndPadImage,
|
||||
"GetImageSize": GetImageSize,
|
||||
}
|
||||
|
@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
|
||||
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
|
||||
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
}}
|
||||
@@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
|
||||
def patch(self, model, sampling, sigma_max, sigma_min):
|
||||
m = model.clone()
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
latent_format = None
|
||||
sigma_data = 1.0
|
||||
if sampling == "eps":
|
||||
@@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
|
||||
sampling_type = comfy.model_sampling.EDM
|
||||
sigma_data = 0.5
|
||||
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
|
||||
elif sampling == "cosmos_rflow":
|
||||
sampling_type = comfy.model_sampling.COSMOS_RFLOW
|
||||
sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
|
||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
|
@@ -268,6 +268,52 @@ class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embedder."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["t_embedding_norm."] = argument
|
||||
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embedder."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["t_embedding_norm."] = argument
|
||||
|
||||
|
||||
for i in range(36):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@@ -281,4 +327,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeCosmos7B": ModelMergeCosmos7B,
|
||||
"ModelMergeCosmos14B": ModelMergeCosmos14B,
|
||||
"ModelMergeWAN2_1": ModelMergeWAN2_1,
|
||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ import comfy.sampler_helpers
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import math
|
||||
|
||||
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
|
||||
pos = noise_pred_pos - noise_pred_nocond
|
||||
@@ -69,8 +70,23 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
empty_cond = self.conds.get("empty_negative_prompt", None)
|
||||
|
||||
(noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
|
||||
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.neg_scale, 0.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg, 1.0):
|
||||
empty_cond = None
|
||||
|
||||
conds = [positive_cond, negative_cond, empty_cond]
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
|
||||
|
||||
# Apply pre_cfg_functions since sampling_function() is skipped
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": self.cfg, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": self.inner_model, "model_options": model_options}
|
||||
out = fn(args)
|
||||
|
||||
noise_pred_pos, noise_pred_neg, noise_pred_empty = out
|
||||
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
|
||||
|
||||
# normally this would be done in cfg_function, but we skipped
|
||||
@@ -82,6 +98,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
"denoised": cfg_result,
|
||||
"cond": positive_cond,
|
||||
"uncond": negative_cond,
|
||||
"cond_scale": self.cfg,
|
||||
"model": self.inner_model,
|
||||
"uncond_denoised": noise_pred_neg,
|
||||
"cond_denoised": noise_pred_pos,
|
||||
|
@@ -296,6 +296,41 @@ class RegexExtract():
|
||||
|
||||
return result,
|
||||
|
||||
|
||||
class RegexReplace():
|
||||
DESCRIPTION = "Find and replace text using regex patterns."
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
||||
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
flags |= re.IGNORECASE
|
||||
if multiline:
|
||||
flags |= re.MULTILINE
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||
return result,
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StringConcatenate": StringConcatenate,
|
||||
"StringSubstring": StringSubstring,
|
||||
@@ -306,7 +341,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"StringContains": StringContains,
|
||||
"StringCompare": StringCompare,
|
||||
"RegexMatch": RegexMatch,
|
||||
"RegexExtract": RegexExtract
|
||||
"RegexExtract": RegexExtract,
|
||||
"RegexReplace": RegexReplace,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -319,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StringContains": "Contains",
|
||||
"StringCompare": "Compare",
|
||||
"RegexMatch": "Regex Match",
|
||||
"RegexExtract": "Regex Extract"
|
||||
"RegexExtract": "Regex Extract",
|
||||
"RegexReplace": "Regex Replace",
|
||||
}
|
||||
|
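A small sketch, not part of the diff, of how the RegexReplace options above compose into a plain re.sub call; count=0 means replace every occurrence:

# Sketch only: equivalent of RegexReplace with case_insensitive and dotall enabled, count=0.
import re

flags = re.IGNORECASE | re.DOTALL
result = re.sub(r"foo.+?bar", "X", "FOO\nbar and foo bar", count=0, flags=flags)
print(result)  # -> "X and X"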
71
comfy_extras/nodes_tcfg.py
Normal file
@@ -0,0 +1,71 @@
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)

import torch

from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict


def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
    """Drop tangential components from uncond score to align with cond score."""
    # (B, 1, ...)
    batch_num = cond_score.shape[0]
    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()

    # Score matrix A (B, 2, ...)
    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
    try:
        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
    except RuntimeError:
        # Fallback to CPU
        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)

    # Drop the tangential components
    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)


class TCFG(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "model": (IO.MODEL, {}),
            }
        }

    RETURN_TYPES = (IO.MODEL,)
    RETURN_NAMES = ("patched_model",)
    FUNCTION = "patch"

    CATEGORY = "advanced/guidance"
    DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."

    def patch(self, model):
        m = model.clone()

        def tangential_damping_cfg(args):
            # Assume [cond, uncond, ...]
            x = args["input"]
            conds_out = args["conds_out"]
            if len(conds_out) <= 1 or None in args["conds"][:2]:
                # Skip when either cond or uncond is None
                return conds_out
            cond_pred = conds_out[0]
            uncond_pred = conds_out[1]
            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
            uncond_pred_td = x - uncond_td
            return [cond_pred, uncond_pred_td] + conds_out[2:]

        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
        return (m,)


NODE_CLASS_MAPPINGS = {
    "TCFG": TCFG,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TCFG": "Tangential Damping CFG",
}
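A tiny sanity-check sketch, not part of the diff, for score_tangential_damping above; it only verifies that the damped uncond score keeps the input shape and dtype:

# Sketch only: shape/dtype check of the tangential damping helper on random scores.
import torch
from comfy_extras.nodes_tcfg import score_tangential_damping

cond = torch.randn(2, 4, 8, 8)    # stand-in (batch, channels, h, w) noise predictions
uncond = torch.randn(2, 4, 8, 8)

damped = score_tangential_damping(cond, uncond)
print(damped.shape == uncond.shape, damped.dtype == uncond.dtype)  # -> True True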
709
comfy_extras/nodes_train.py
Normal file
@@ -0,0 +1,709 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import safetensors
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import torch.utils.checkpoint
|
||||
import tqdm
|
||||
|
||||
import comfy.samplers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy.weight_adapter import adapters
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
self.optimizer.zero_grad()
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
|
||||
latent = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(sigmas),
|
||||
torch.zeros_like(noise, requires_grad=True),
|
||||
latent_image,
|
||||
False
|
||||
)
|
||||
|
||||
# Ensure model is in training mode and computing gradients
|
||||
# x0 pred
|
||||
denoised = model_wrap(noise, sigmas, **extra_args)
|
||||
try:
|
||||
loss = self.loss_fn(denoised, latent.clone())
|
||||
except RuntimeError as e:
|
||||
if "does not require grad and does not have a grad_fn" in str(e):
|
||||
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
|
||||
self.optimizer.step()
|
||||
# torch.cuda.memory._dump_snapshot("trainn.pickle")
|
||||
# torch.cuda.memory._record_memory_history(enabled=None)
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
|
||||
class BiasDiff(torch.nn.Module):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
self.bias = bias
|
||||
|
||||
def __call__(self, b):
|
||||
org_dtype = b.dtype
|
||||
return (b.to(self.bias) + self.bias).to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return self.bias.nelement() * self.bias.element_size()
|
||||
|
||||
def move_to(self, device):
|
||||
self.to(device=device)
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def load_and_process_images(image_files, input_dir, resize_method="None"):
|
||||
"""Utility function to load and process a list of images.
|
||||
|
||||
Args:
|
||||
image_files: List of image filenames
|
||||
input_dir: Base directory containing the images
|
||||
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Batch of processed images
|
||||
"""
|
||||
if not image_files:
|
||||
raise ValueError("No valid images found in input")
|
||||
|
||||
output_images = []
|
||||
w, h = None, None
|
||||
|
||||
for file in image_files:
|
||||
image_path = os.path.join(input_dir, file)
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
if img.mode == "I":
|
||||
img = img.point(lambda i: i * (1 / 255))
|
||||
img = img.convert("RGB")
|
||||
|
||||
if w is None and h is None:
|
||||
w, h = img.size[0], img.size[1]
|
||||
|
||||
# Resize image to first image
|
||||
if img.size[0] != w or img.size[1] != h:
|
||||
if resize_method == "Stretch":
|
||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||
elif resize_method == "Crop":
|
||||
img = img.crop((0, 0, w, h))
|
||||
elif resize_method == "Pad":
|
||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||
elif resize_method == "None":
|
||||
raise ValueError(
|
||||
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
||||
)
|
||||
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)[None,]
|
||||
output_images.append(img_tensor)
|
||||
|
||||
return torch.cat(output_images, dim=0)
|
||||
|
||||
|
||||
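A minimal call sketch, not part of the diff, for load_and_process_images above; the file names and directory are hypothetical:

# Sketch only: load two hypothetical images into one (N, H, W, 3) float tensor in [0, 1].
from comfy_extras.nodes_train import load_and_process_images

batch = load_and_process_images(
    image_files=["cat_01.png", "cat_02.png"],   # hypothetical files
    input_dir="/path/to/dataset",               # hypothetical directory
    resize_method="Stretch",                    # later images are resized to match the first
)
print(batch.shape)  # e.g. torch.Size([2, 512, 512, 3])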
class LoadImageSetNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"images": (
|
||||
[
|
||||
f
|
||||
for f in os.listdir(folder_paths.get_input_directory())
|
||||
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
|
||||
],
|
||||
{"image_upload": True, "allow_batch": True},
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
"resize_method": (
|
||||
["None", "Stretch", "Crop", "Pad"],
|
||||
{"default": "None"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
INPUT_IS_LIST = True
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "load_images"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, images, resize_method):
|
||||
filenames = images[0] if isinstance(images[0], list) else images
|
||||
|
||||
for image in filenames:
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
return True
|
||||
|
||||
def load_images(self, input_files, resize_method):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
|
||||
image_files = [
|
||||
f
|
||||
for f in input_files
|
||||
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
]
|
||||
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
|
||||
return (output_tensor,)
|
||||
|
||||
|
||||
class LoadImageSetFromFolderNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
|
||||
},
|
||||
"optional": {
|
||||
"resize_method": (
|
||||
["None", "Stretch", "Crop", "Pad"],
|
||||
{"default": "None"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "load_images"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||
|
||||
def load_images(self, folder, resize_method):
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||
image_files = [
|
||||
f
|
||||
for f in os.listdir(sub_input_dir)
|
||||
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
]
|
||||
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
|
||||
return (output_tensor,)
|
||||
|
||||
|
||||
def draw_loss_graph(loss_map, steps):
|
||||
width, height = 500, 300
|
||||
img = Image.new("RGB", (width, height), "white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
|
||||
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
|
||||
|
||||
prev_point = (0, height - int(scaled_loss[0] * height))
|
||||
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||
x = int(i / (steps - 1) * width)
|
||||
y = height - int(l * height)
|
||||
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||
prev_point = (x, y)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
|
||||
if result is None:
|
||||
result = []
|
||||
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
|
||||
result.append(model)
|
||||
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
|
||||
return result
|
||||
name = name or "root"
|
||||
for next_name, child in model.named_children():
|
||||
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
|
||||
return result
|
||||
|
||||
|
||||
def patch(m):
|
||||
if not hasattr(m, "forward"):
|
||||
return
|
||||
org_forward = m.forward
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
fwd, args, kwargs, use_reentrant=False
|
||||
)
|
||||
m.org_forward = org_forward
|
||||
m.forward = checkpointing_fwd
|
||||
|
||||
|
||||
def unpatch(m):
|
||||
if hasattr(m, "org_forward"):
|
||||
m.forward = m.org_forward
|
||||
del m.org_forward
|
||||
|
||||
|
||||
class TrainLoraNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
|
||||
"latents": (
|
||||
"LATENT",
|
||||
{
|
||||
"tooltip": "The Latents to use for training, serve as dataset/input of the model."
|
||||
},
|
||||
),
|
||||
"positive": (
|
||||
IO.CONDITIONING,
|
||||
{"tooltip": "The positive conditioning to use for training."},
|
||||
),
|
||||
"batch_size": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 10000,
|
||||
"step": 1,
|
||||
"tooltip": "The batch size to use for training.",
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 16,
|
||||
"min": 1,
|
||||
"max": 100000,
|
||||
"tooltip": "The number of steps to train the LoRA for.",
|
||||
},
|
||||
),
|
||||
"learning_rate": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.0005,
|
||||
"min": 0.0000001,
|
||||
"max": 1.0,
|
||||
"step": 0.000001,
|
||||
"tooltip": "The learning rate to use for training.",
|
||||
},
|
||||
),
|
||||
"rank": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 1,
|
||||
"max": 128,
|
||||
"tooltip": "The rank of the LoRA layers.",
|
||||
},
|
||||
),
|
||||
"optimizer": (
|
||||
["AdamW", "Adam", "SGD", "RMSprop"],
|
||||
{
|
||||
"default": "AdamW",
|
||||
"tooltip": "The optimizer to use for training.",
|
||||
},
|
||||
),
|
||||
"loss_function": (
|
||||
["MSE", "L1", "Huber", "SmoothL1"],
|
||||
{
|
||||
"default": "MSE",
|
||||
"tooltip": "The loss function to use for training.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
|
||||
},
|
||||
),
|
||||
"training_dtype": (
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for training."},
|
||||
),
|
||||
"lora_dtype": (
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||
),
|
||||
"existing_lora": (
|
||||
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
{
|
||||
"default": "[None]",
|
||||
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
|
||||
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
|
||||
FUNCTION = "train"
|
||||
CATEGORY = "training"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def train(
|
||||
self,
|
||||
model,
|
||||
latents,
|
||||
positive,
|
||||
batch_size,
|
||||
steps,
|
||||
learning_rate,
|
||||
rank,
|
||||
optimizer,
|
||||
loss_function,
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
existing_lora,
|
||||
):
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
latents = latents["samples"].to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
|
||||
with torch.inference_mode(False):
|
||||
lora_sd = {}
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights = {}
|
||||
existing_steps = 0
|
||||
if existing_lora != "[None]":
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
|
||||
all_weight_adapters = []
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
key = "{}.weight".format(n)
|
||||
shape = m.weight.shape
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(
|
||||
f"{key}.dora_scale", None
|
||||
)
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
n, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
# If no existing adapter found, use LoRA
|
||||
# We will add algo option in the future
|
||||
existing_adapter = None
|
||||
adapter_cls = adapters[0]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
m.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_sd[f"{n}.{name}"] = parameter
|
||||
|
||||
mp.add_weight_wrapper(key, train_adapter)
|
||||
all_weight_adapters.append(train_adapter)
|
||||
else:
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||
all_weight_adapters.append(diff_module)
|
||||
lora_sd["{}.diff".format(n)] = diff
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
key = "{}.bias".format(n)
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
lora_sd["{}.diff_b".format(n)] = bias
|
||||
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||
all_weight_adapters.append(bias_module)
|
||||
|
||||
if optimizer == "Adam":
|
||||
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "AdamW":
|
||||
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "SGD":
|
||||
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "RMSprop":
|
||||
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||
|
||||
# Setup loss function based on selection
|
||||
if loss_function == "MSE":
|
||||
criterion = torch.nn.MSELoss()
|
||||
elif loss_function == "L1":
|
||||
criterion = torch.nn.L1Loss()
|
||||
elif loss_function == "Huber":
|
||||
criterion = torch.nn.HuberLoss()
|
||||
elif loss_function == "SmoothL1":
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
||||
|
||||
# Setup sampler and guider like in test script
|
||||
loss_map = {"loss": []}
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
pbar.set_postfix({"loss": f"{loss:.4f}"})
|
||||
train_sampler = TrainSampler(
|
||||
criterion, optimizer, loss_callback=loss_callback
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
|
||||
|
||||
# yoland: this currently resize to the first image in the dataset
|
||||
|
||||
# Training loop
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
# Generate random sigma
|
||||
sigma = mp.model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
sigma = torch.tensor([sigma])
|
||||
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||
|
||||
indices = torch.randperm(num_images)[:batch_size]
|
||||
ss.sample(
|
||||
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del ss, train_sampler, optimizer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
|
||||
return (mp, lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader:
|
||||
def __init__(self):
|
||||
self.loaded_lora = None
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
||||
"lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
|
||||
FUNCTION = "load_lora_model"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def load_lora_model(self, model, lora, strength_model):
|
||||
if strength_model == 0:
|
||||
return (model, )
|
||||
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
|
||||
return (model_lora, )
|
||||
|
||||
|
||||
class SaveLoRA:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"lora": (
|
||||
IO.LORA_MODEL,
|
||||
{
|
||||
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
|
||||
},
|
||||
),
|
||||
"prefix": (
|
||||
"STRING",
|
||||
{
|
||||
"default": "loras/ComfyUI_trained_lora",
|
||||
"tooltip": "The prefix to use for the saved LoRA file.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"forceInput": True,
|
||||
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def save(self, lora, prefix, steps=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
|
||||
if steps is None:
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
else:
|
||||
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
safetensors.torch.save_file(lora, output_checkpoint)
|
||||
return {}
|
||||
|
||||
|
||||
class LossGraphNode:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"loss": (IO.LOSS_MAP, {"default": {}}),
|
||||
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "plot_loss"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "training"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
|
||||
|
||||
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
loss_values = loss["loss"]
|
||||
width, height = 800, 480
|
||||
margin = 40
|
||||
|
||||
img = Image.new(
|
||||
"RGB", (width + margin, height + margin), "white"
|
||||
) # Extend canvas
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
min_loss, max_loss = min(loss_values), max(loss_values)
|
||||
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
|
||||
|
||||
steps = len(loss_values)
|
||||
|
||||
prev_point = (margin, height - int(scaled_loss[0] * height))
|
||||
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||
x = margin + int(i / steps * width) # Scale X properly
|
||||
y = height - int(l * height)
|
||||
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||
prev_point = (x, y)
|
||||
|
||||
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
|
||||
draw.line(
|
||||
[(margin, height), (width + margin, height)], fill="black", width=2
|
||||
) # X-axis
|
||||
|
||||
font = None
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 12)
|
||||
except IOError:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
# Add axis labels
|
||||
draw.text((5, height // 2), "Loss", font=font, fill="black")
|
||||
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
|
||||
|
||||
# Add min/max loss values
|
||||
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
|
||||
draw.text(
|
||||
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
|
||||
)
|
||||
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add_text("prompt", json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||
|
||||
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
img.save(
|
||||
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
|
||||
pnginfo=metadata,
|
||||
)
|
||||
return {
|
||||
"ui": {
|
||||
"images": [
|
||||
{
|
||||
"filename": f"{filename_prefix}_{date}.png",
|
||||
"subfolder": "",
|
||||
"type": "temp",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TrainLoraNode": TrainLoraNode,
|
||||
"SaveLoRANode": SaveLoRA,
|
||||
"LoraModelLoader": LoraModelLoader,
|
||||
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
|
||||
"LossGraphNode": LossGraphNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TrainLoraNode": "Train LoRA",
|
||||
"SaveLoRANode": "Save LoRA Weights",
|
||||
"LoraModelLoader": "Load LoRA Model",
|
||||
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
|
||||
"LossGraphNode": "Plot Loss Graph",
|
||||
}
|
@@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage):
    def load_capture(self, image, **kwargs):
        return super().load_image(folder_paths.get_annotated_filepath(image))

    @classmethod
    def IS_CHANGED(cls, image, width, height, capture_on_queue):
        return super().IS_CHANGED(image)


NODE_CLASS_MAPPINGS = {
    "WebcamCapture": WebcamCapture,
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.39"
__version__ = "0.3.43"
174
execution.py
@@ -1,22 +1,38 @@
|
||||
import sys
|
||||
import copy
|
||||
import logging
|
||||
import threading
|
||||
import heapq
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
DependencyAwareCache,
|
||||
HierarchicalCache,
|
||||
LRUCache,
|
||||
)
|
||||
from comfy_execution.graph import (
|
||||
DynamicPrompt,
|
||||
ExecutionBlocker,
|
||||
ExecutionList,
|
||||
get_input_info,
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@@ -27,12 +43,13 @@ class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
if node_id in self.is_changed:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -50,7 +67,8 @@ class IsChangedCache:
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
logging.warning("WARNING: {}".format(e))
|
||||
@@ -152,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def resolve_map_node_over_list_results(results):
|
||||
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
|
||||
if len(remaining) == 0:
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
else:
|
||||
done, pending = await asyncio.wait(remaining)
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -166,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
results = []
|
||||
def process_inputs(inputs, index=None, input_is_list=False):
|
||||
async def process_inputs(inputs, index=None, input_is_list=False):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
execution_block = None
|
||||
@@ -182,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
results.append(getattr(obj, func)(**inputs))
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
result = task.result()
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(execution_block)
|
||||
|
||||
if input_is_list:
|
||||
process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
elif max_len_input == 0:
|
||||
process_inputs({})
|
||||
await process_inputs({})
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
process_inputs(input_dict, i)
|
||||
await process_inputs(input_dict, i)
|
||||
return results
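# --- Editor's illustration (not part of this diff) ---
# A minimal sketch of the kind of node the async path above is meant to support:
# when FUNCTION resolves to a coroutine function, _async_map_node_over_list wraps the
# call in an asyncio.Task instead of blocking the executor. The class below is
# hypothetical (modeled on the TestSleep node used in the tests later in this
# compare); only the `async def` entry point is the point being shown.
import asyncio

class SleepBriefly:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"value": ("IMAGE",), "seconds": ("FLOAT", {"default": 0.5})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "_for_testing"

    async def run(self, value, seconds):
        # While this await is pending, the executor can stage other ready nodes.
        await asyncio.sleep(seconds)
        return (value,)
# --- end of editor's illustration ---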
|
||||
|
||||
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
@@ -217,11 +264,18 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
|
||||
return output, ui, has_subgraph, False
|
||||
|
||||
def get_output_from_returns(return_values, obj):
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
@@ -255,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
# TODO: Think there's an existing bug here
|
||||
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
|
||||
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
|
||||
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui, has_subgraph
|
||||
@@ -267,7 +325,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@@ -279,11 +337,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
if unique_id in pending_subgraph_results:
|
||||
if unique_id in pending_async_nodes:
|
||||
results = []
|
||||
for r in pending_async_nodes[unique_id]:
|
||||
if isinstance(r, asyncio.Task):
|
||||
try:
|
||||
results.append(r.result())
|
||||
except Exception as ex:
|
||||
# An async task failed - propagate the exception up
|
||||
del pending_async_nodes[unique_id]
|
||||
raise ex
|
||||
else:
|
||||
results.append(r)
|
||||
del pending_async_nodes[unique_id]
|
||||
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||
elif unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
@@ -305,6 +378,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
@@ -316,7 +390,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
@@ -345,8 +420,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
else:
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
async def await_completion():
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
@@ -389,7 +474,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
@@ -417,20 +503,24 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
|
||||
logging.error(f"!!! Exception during processing !!! {ex}")
|
||||
logging.error(traceback.format_exc())
|
||||
tips = ""
|
||||
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
error_details = {
|
||||
"node_id": real_node_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_message": "{}\n{}".format(ex, tips),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted
|
||||
}
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@@ -485,6 +575,11 @@ class PromptExecutor:
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@@ -497,9 +592,11 @@ class PromptExecutor:
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
cached_nodes = []
|
||||
@@ -512,6 +609,7 @@ class PromptExecutor:
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
@@ -519,12 +617,13 @@ class PromptExecutor:
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = execution_list.stage_node_execution()
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@@ -554,7 +653,7 @@ class PromptExecutor:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
@@ -631,7 +730,7 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
r = await validate_inputs(prompt_id, prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
@@ -756,7 +855,8 @@ def validate_inputs(prompt, item, validated):
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
@@ -789,7 +889,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -832,7 +932,7 @@ def validate_prompt(prompt):
|
||||
valid = False
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
m = await validate_inputs(prompt_id, prompt, o, validated)
|
||||
valid = m[0]
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
|
@@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])


def get_full_path(folder_name: str, filename: str) -> str | None:
    """
    Get the full path of a file in a folder, has to be a file
    """
    global folder_names_and_paths
    folder_name = map_legacy(folder_name)
    if folder_name not in folder_names_and_paths:
@@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:


def get_full_path_or_raise(folder_name: str, filename: str) -> str:
    """
    Get the full path of a file in a folder, has to be a file
    """
    full_path = get_full_path(folder_name, filename)
    if full_path is None:
        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
        os.makedirs(full_output_folder, exist_ok=True)
        counter = 1
    return full_output_folder, filename, counter, subfolder, filename_prefix

def get_input_subfolders() -> list[str]:
    """Returns a list of all subfolder paths in the input directory, recursively.

    Returns:
        List of folder paths relative to the input directory, excluding the root directory
    """
    input_dir = get_input_directory()
    folders = []

    try:
        if not os.path.exists(input_dir):
            return []

        for root, dirs, _ in os.walk(input_dir):
            rel_path = os.path.relpath(root, input_dir)
            if rel_path != ".":  # Only include non-root directories
                # Normalize path separators to forward slashes
                folders.append(rel_path.replace(os.sep, '/'))

        return sorted(folders)
    except FileNotFoundError:
        return []
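# --- Editor's illustration (not part of this diff) ---
# Usage sketch for the helper above: paths come back relative to the input directory,
# use forward slashes on every platform, and never include the root itself.
from folder_paths import get_input_subfolders

for subfolder in get_input_subfolders():
    # e.g. "people" and "people/portraits" for an input tree input/people/portraits
    print(subfolder)
# --- end of editor's illustration ---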
main.py
@@ -11,13 +11,14 @@ import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
def apply_custom_paths():
|
||||
@@ -56,6 +57,9 @@ def apply_custom_paths():
|
||||
|
||||
|
||||
def execute_prestartup_script():
|
||||
if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
|
||||
return
|
||||
|
||||
def execute_script(script_path):
|
||||
module_name = os.path.splitext(script_path)[0]
|
||||
try:
|
||||
@@ -67,9 +71,6 @@ def execute_prestartup_script():
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
return False
|
||||
|
||||
if args.disable_all_custom_nodes:
|
||||
return
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
@@ -82,6 +83,9 @@ def execute_prestartup_script():
|
||||
|
||||
script_path = os.path.join(module_path, "prestartup_script.py")
|
||||
if os.path.exists(script_path):
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = execute_script(script_path)
|
||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@@ -129,7 +133,7 @@ import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from server import BinaryEventTypes
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
@@ -186,7 +190,13 @@ def prompt_worker(q, server_instance):
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
# Log Time in a more readable way after 10 minutes
|
||||
if execution_time > 600:
|
||||
execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
|
||||
logging.info(f"Prompt executed in {execution_time}")
|
||||
else:
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
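# --- Editor's note (illustrative, not part of this diff) ---
# Worked example for the readable-time branch above: a prompt that takes 754 seconds
# is now logged as time.strftime("%H:%M:%S", time.gmtime(754)) == "00:12:34", i.e.
# "Prompt executed in 00:12:34", while anything at or under ten minutes keeps the
# original "Prompt executed in 42.00 seconds" style of message.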
|
||||
@@ -219,14 +229,25 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||
)
|
||||
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
def hook(value, total, preview_image):
|
||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||
executing_context = get_executing_context()
|
||||
if prompt_id is None and executing_context is not None:
|
||||
prompt_id = executing_context.prompt_id
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
||||
if prompt_id is None:
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
if preview_image is not None:
|
||||
# Also send old method for backward compatibility
|
||||
# TODO - Remove after this repo is updated to frontend with metadata support
|
||||
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
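# --- Editor's illustration (not part of this diff) ---
# Sketch of how a long-running node body might report progress under the new hook.
# The hook above now resolves prompt_id/node_id from the executing context when the
# caller does not pass them, so progress can be attributed per node even when several
# async nodes run at once. ProgressBar is the existing helper in comfy.utils; the
# loop below is hypothetical.
import comfy.utils

def run_with_progress(steps=20):
    pbar = comfy.utils.ProgressBar(steps)
    for i in range(steps):
        ...  # one unit of real work would go here
        pbar.update_absolute(i + 1)  # forwarded to the global hook installed above
# --- end of editor's illustration ---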
|
||||
@@ -238,6 +259,15 @@ def cleanup_temp():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
def start_comfyui(asyncio_loop=None):
|
||||
"""
|
||||
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
|
||||
@@ -262,10 +292,14 @@ def start_comfyui(asyncio_loop=None):
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
|
||||
nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
@@ -300,6 +334,9 @@ if __name__ == "__main__":
|
||||
logging.info("Python version: {}".format(sys.version))
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
try:
|
||||
x = start_all_func()
|
||||
|
nodes.py
@@ -920,7 +920,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@@ -930,7 +930,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@@ -2061,11 +2061,13 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||
"ImageBatch": "Batch Images",
|
||||
"ImageCrop": "Image Crop",
|
||||
"ImageStitch": "Image Stitch",
|
||||
"ImageBlend": "Image Blend",
|
||||
"ImageBlur": "Image Blur",
|
||||
"ImageQuantize": "Image Quantize",
|
||||
"ImageSharpen": "Image Sharpen",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# _for_testing
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
@@ -2123,6 +2125,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
||||
|
||||
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
|
||||
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
project_config = config_parser.extract_node_configuration(module_path)
|
||||
|
||||
web_dir_name = project_config.tool_comfy.web
|
||||
|
||||
if web_dir_name:
|
||||
web_dir_path = os.path.join(module_path, web_dir_name)
|
||||
|
||||
if os.path.isdir(web_dir_path):
|
||||
project_name = project_config.project.name
|
||||
|
||||
EXTENSION_WEB_DIRS[project_name] = web_dir_path
|
||||
|
||||
logging.info("Automatically register web folder {} for {}".format(web_dir_name, project_name))
|
||||
except Exception as e:
|
||||
logging.warning(f"Unable to parse pyproject.toml due to lack dependency pydantic-settings, please run 'pip install -r requirements.txt': {e}")
|
||||
|
||||
if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
|
||||
web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
|
||||
if os.path.isdir(web_dir):
|
||||
@@ -2166,6 +2187,9 @@ def init_external_custom_nodes():
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@@ -2210,6 +2234,7 @@ def init_builtin_extra_nodes():
|
||||
"nodes_model_downscale.py",
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
"nodes_train.py",
|
||||
"nodes_sag.py",
|
||||
"nodes_perpneg.py",
|
||||
"nodes_stable3d.py",
|
||||
@@ -2257,6 +2282,8 @@ def init_builtin_extra_nodes():
|
||||
"nodes_ace.py",
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
protocol.py (new file)
@@ -0,0 +1,7 @@

class BinaryEventTypes:
    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2
    TEXT = 3
    PREVIEW_IMAGE_WITH_METADATA = 4
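# --- Editor's illustration (not part of this diff) ---
# Hedged sketch of a consumer branching on the constants above; the wire framing
# itself is handled by PromptServer (see the server.py changes later in this compare).
from protocol import BinaryEventTypes

EVENT_NAMES = {
    BinaryEventTypes.PREVIEW_IMAGE: "preview image",
    BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: "unencoded preview image",
    BinaryEventTypes.TEXT: "text payload",
    BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA: "preview image + JSON metadata",
}

def describe(event_type: int) -> str:
    return EVENT_NAMES.get(event_type, "unknown binary event")
# --- end of editor's illustration ---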
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.39"
version = "0.3.43"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -1,12 +1,13 @@
comfyui-frontend-package==1.20.7
comfyui-workflow-templates==0.1.22
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.31
comfyui-embedded-docs==0.2.3
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
transformers>=4.37.2
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
@@ -17,6 +18,8 @@ Pillow
scipy
tqdm
psutil
alembic
SQLAlchemy

#non essential dependencies:
kornia>=0.7.1
@@ -24,3 +27,4 @@ spandrel
soundfile
av>=14.2.0
pydantic~=2.0
pydantic-settings~=2.0
server.py
@@ -35,11 +35,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
@@ -390,7 +386,7 @@ class PromptServer():
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
@@ -476,9 +472,8 @@ class PromptServer():
|
||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
|
||||
# For security, force certain extensions to download instead of display
|
||||
file_extension = os.path.splitext(filename)[1].lower()
|
||||
if file_extension in {'.html', '.htm', '.js', '.css'}:
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
content_type = 'application/octet-stream' # Forces download
|
||||
|
||||
return web.FileResponse(
|
||||
@@ -644,7 +639,8 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
valid = execution.validate_prompt(prompt)
|
||||
prompt_id = str(uuid.uuid4())
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
@@ -652,7 +648,6 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
@@ -746,6 +741,13 @@ class PromptServer():
|
||||
web.static('/templates', workflow_templates_path)
|
||||
])
|
||||
|
||||
# Serve embedded documentation from the package
|
||||
embedded_docs_path = FrontendManager.embedded_docs_path()
|
||||
if embedded_docs_path:
|
||||
self.app.add_routes([
|
||||
web.static('/docs', embedded_docs_path)
|
||||
])
|
||||
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root),
|
||||
])
|
||||
@@ -760,6 +762,10 @@ class PromptServer():
|
||||
async def send(self, event, data, sid=None):
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
await self.send_image(data, sid=sid)
|
||||
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
|
||||
# data is (preview_image, metadata)
|
||||
preview_image, metadata = data
|
||||
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
await self.send_bytes(event, data, sid)
|
||||
else:
|
||||
@@ -782,7 +788,7 @@ class PromptServer():
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.ANTIALIAS
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
type_num = 1
|
||||
@@ -798,6 +804,43 @@
            preview_bytes = bytesIO.getvalue()
            await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

    async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.Resampling.LANCZOS

            image = ImageOps.contain(image, (max_size, max_size), resampling)

        mimetype = "image/png" if image_type == "PNG" else "image/jpeg"

        # Prepare metadata
        if metadata is None:
            metadata = {}
        metadata["image_type"] = mimetype

        # Serialize metadata as JSON
        import json
        metadata_json = json.dumps(metadata).encode('utf-8')
        metadata_length = len(metadata_json)

        # Prepare image data
        bytesIO = BytesIO()
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        image_bytes = bytesIO.getvalue()

        # Combine metadata and image
        combined_data = bytearray()
        combined_data.extend(struct.pack(">I", metadata_length))
        combined_data.extend(metadata_json)
        combined_data.extend(image_bytes)

        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)

    async def send_bytes(self, event, data, sid=None):
        message = self.encode_bytes(event, data)
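# --- Editor's illustration (not part of this diff) ---
# Hedged sketch of decoding the payload built in send_image_with_metadata above:
# a 4-byte big-endian length prefix, then that many bytes of JSON metadata, then the
# raw encoded image. This covers combined_data only; whatever event-type framing
# send_bytes/encode_bytes adds around it is not shown in this hunk.
import json
import struct

def split_metadata_frame(combined_data: bytes):
    (metadata_length,) = struct.unpack_from(">I", combined_data, 0)
    metadata = json.loads(combined_data[4:4 + metadata_length].decode("utf-8"))
    image_bytes = combined_data[4 + metadata_length:]
    return metadata, image_bytes  # metadata["image_type"] is "image/png" or "image/jpeg"
# --- end of editor's illustration ---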
tests-unit/comfy_extras_test/__init__.py (new, empty file)
tests-unit/comfy_extras_test/image_stitch_test.py (new file)
@@ -0,0 +1,243 @@
|
||||
import torch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Mock nodes module to prevent CUDA initialization during import
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
|
||||
# Mock server module for PromptServer
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
|
||||
from comfy_extras.nodes_images import ImageStitch
|
||||
|
||||
|
||||
class TestImageStitch:
|
||||
|
||||
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
|
||||
"""Helper to create test images with specific dimensions"""
|
||||
return torch.rand(batch_size, height, width, channels)
|
||||
|
||||
def test_no_image2_passthrough(self):
|
||||
"""Test that when image2 is None, image1 is returned unchanged"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image()
|
||||
|
||||
result = node.stitch(image1, "right", True, 0, "white", image2=None)
|
||||
|
||||
assert len(result) == 1
|
||||
assert torch.equal(result[0], image1)
|
||||
|
||||
def test_basic_horizontal_stitch_right(self):
|
||||
"""Test basic horizontal stitching to the right"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=24)
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
|
||||
|
||||
def test_basic_horizontal_stitch_left(self):
|
||||
"""Test basic horizontal stitching to the left"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=24)
|
||||
|
||||
result = node.stitch(image1, "left", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
|
||||
|
||||
def test_basic_vertical_stitch_down(self):
|
||||
"""Test basic vertical stitching downward"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=24, width=32)
|
||||
|
||||
result = node.stitch(image1, "down", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
|
||||
|
||||
def test_basic_vertical_stitch_up(self):
|
||||
"""Test basic vertical stitching upward"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=24, width=32)
|
||||
|
||||
result = node.stitch(image1, "up", False, 0, "white", image2)
|
||||
|
||||
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
|
||||
|
||||
def test_size_matching_horizontal(self):
|
||||
"""Test size matching for horizontal concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=64, width=64)
|
||||
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio
|
||||
|
||||
result = node.stitch(image1, "right", True, 0, "white", image2)
|
||||
|
||||
# image2 should be resized to match image1's height (64) with preserved aspect ratio
|
||||
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
|
||||
assert result[0].shape == (1, 64, expected_width, 3)
|
||||
|
||||
def test_size_matching_vertical(self):
|
||||
"""Test size matching for vertical concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=64, width=64)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
result = node.stitch(image1, "down", True, 0, "white", image2)
|
||||
|
||||
# image2 should be resized to match image1's width (64) with preserved aspect ratio
|
||||
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
|
||||
assert result[0].shape == (1, expected_height, 64, 3)
|
||||
|
||||
def test_padding_for_mismatched_heights_horizontal(self):
|
||||
"""Test padding when heights don't match in horizontal concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=64, width=32)
|
||||
image2 = self.create_test_image(height=48, width=24) # Shorter height
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Both images should be padded to height 64
|
||||
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
|
||||
|
||||
def test_padding_for_mismatched_widths_vertical(self):
|
||||
"""Test padding when widths don't match in vertical concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=64)
|
||||
image2 = self.create_test_image(height=24, width=48) # Narrower width
|
||||
|
||||
result = node.stitch(image1, "down", False, 0, "white", image2)
|
||||
|
||||
# Both images should be padded to width 64
|
||||
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
|
||||
|
||||
def test_spacing_horizontal(self):
|
||||
"""Test spacing addition in horizontal concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=24)
|
||||
spacing_width = 16
|
||||
|
||||
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
|
||||
|
||||
# Expected width: 32 + 16 (spacing) + 24 = 72
|
||||
assert result[0].shape == (1, 32, 72, 3)
|
||||
|
||||
def test_spacing_vertical(self):
|
||||
"""Test spacing addition in vertical concatenation"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=24, width=32)
|
||||
spacing_width = 16
|
||||
|
||||
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
|
||||
|
||||
# Expected height: 32 + 16 (spacing) + 24 = 72
|
||||
assert result[0].shape == (1, 72, 32, 3)
|
||||
|
||||
def test_spacing_color_values(self):
|
||||
"""Test that spacing colors are applied correctly"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
# Test white spacing
|
||||
result_white = node.stitch(image1, "right", False, 16, "white", image2)
|
||||
# Check that spacing region contains white values (close to 1.0)
|
||||
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
|
||||
assert torch.all(spacing_region >= 0.9) # Should be close to white
|
||||
|
||||
# Test black spacing
|
||||
result_black = node.stitch(image1, "right", False, 16, "black", image2)
|
||||
spacing_region = result_black[0][:, :, 32:48, :]
|
||||
assert torch.all(spacing_region <= 0.1) # Should be close to black
|
||||
|
||||
def test_odd_spacing_width_made_even(self):
|
||||
"""Test that odd spacing widths are made even"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
# Use odd spacing width
|
||||
result = node.stitch(image1, "right", False, 15, "white", image2)
|
||||
|
||||
# Should be made even (16), so total width = 32 + 16 + 32 = 80
|
||||
assert result[0].shape == (1, 32, 80, 3)
|
||||
|
||||
def test_batch_size_matching(self):
|
||||
"""Test that different batch sizes are handled correctly"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(batch_size=2, height=32, width=32)
|
||||
image2 = self.create_test_image(batch_size=1, height=32, width=32)
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Should match larger batch size
|
||||
assert result[0].shape == (2, 32, 64, 3)
|
||||
|
||||
def test_channel_matching_rgb_to_rgba(self):
|
||||
"""Test that channel differences are handled (RGB + alpha)"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(channels=3) # RGB
|
||||
image2 = self.create_test_image(channels=4) # RGBA
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Should have 4 channels (RGBA)
|
||||
assert result[0].shape[-1] == 4
|
||||
|
||||
def test_channel_matching_rgba_to_rgb(self):
|
||||
"""Test that channel differences are handled (RGBA + RGB)"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(channels=4) # RGBA
|
||||
image2 = self.create_test_image(channels=3) # RGB
|
||||
|
||||
result = node.stitch(image1, "right", False, 0, "white", image2)
|
||||
|
||||
# Should have 4 channels (RGBA)
|
||||
assert result[0].shape[-1] == 4
|
||||
|
||||
def test_all_color_options(self):
|
||||
"""Test all available color options"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
colors = ["white", "black", "red", "green", "blue"]
|
||||
|
||||
for color in colors:
|
||||
result = node.stitch(image1, "right", False, 16, color, image2)
|
||||
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
|
||||
|
||||
def test_all_directions(self):
|
||||
"""Test all direction options"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(height=32, width=32)
|
||||
image2 = self.create_test_image(height=32, width=32)
|
||||
|
||||
directions = ["right", "left", "up", "down"]
|
||||
|
||||
for direction in directions:
|
||||
result = node.stitch(image1, direction, False, 0, "white", image2)
|
||||
assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
|
||||
|
||||
def test_batch_size_channel_spacing_integration(self):
|
||||
"""Test integration of batch matching, channel matching, size matching, and spacings"""
|
||||
node = ImageStitch()
|
||||
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
|
||||
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
|
||||
|
||||
result = node.stitch(image1, "right", True, 8, "red", image2)
|
||||
|
||||
# Should handle: batch matching, size matching, channel matching, spacing
|
||||
assert result[0].shape[0] == 2 # Batch size matched
|
||||
assert result[0].shape[-1] == 4 # Channels matched to max
|
||||
assert result[0].shape[1] == 64 # Height from image1 (size matching)
|
||||
# Width should be: 48 + 8 (spacing) + resized_image2_width
|
||||
expected_image2_width = int(64 * (32/32)) # Resized to height 64
|
||||
expected_total_width = 48 + 8 + expected_image2_width
|
||||
assert result[0].shape[2] == expected_total_width
|
||||
|
tests-unit/folder_paths_test/misc_test.py (new file)
@@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from folder_paths import get_input_subfolders, set_input_directory
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_folder_structure():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create a nested folder structure
|
||||
folders = [
|
||||
"folder1",
|
||||
"folder1/subfolder1",
|
||||
"folder1/subfolder2",
|
||||
"folder2",
|
||||
"folder2/deep",
|
||||
"folder2/deep/nested",
|
||||
"empty_folder"
|
||||
]
|
||||
|
||||
# Create the folders
|
||||
for folder in folders:
|
||||
os.makedirs(os.path.join(temp_dir, folder))
|
||||
|
||||
# Add some files to test they're not included
|
||||
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
|
||||
f.write("test")
|
||||
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
|
||||
f.write("test")
|
||||
|
||||
set_input_directory(temp_dir)
|
||||
yield temp_dir
|
||||
|
||||
|
||||
def test_gets_all_folders(mock_folder_structure):
|
||||
folders = get_input_subfolders()
|
||||
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
|
||||
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
|
||||
assert sorted(folders) == sorted(expected)
|
||||
|
||||
|
||||
def test_handles_nonexistent_input_directory():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
nonexistent = os.path.join(temp_dir, "nonexistent")
|
||||
set_input_directory(nonexistent)
|
||||
assert get_input_subfolders() == []
|
||||
|
||||
|
||||
def test_empty_input_directory():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
set_input_directory(temp_dir)
|
||||
assert get_input_subfolders() == [] # Empty since we don't include root
|
@@ -1,3 +1,4 @@
pytest>=7.8.0
pytest-aiohttp
pytest-asyncio
websocket-client
@@ -1,4 +1,4 @@
# Config for testing nodes
testing:
  custom_nodes: tests/inference/testing_nodes
  custom_nodes: testing_nodes
tests/inference/test_async_nodes.py (new file)
@@ -0,0 +1,410 @@
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
import urllib.error
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestAsyncNodes:
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
# Running server with args: pargs
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
client = ComfyClient()
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||
except ConnectionRefusedError:
|
||||
# Retrying...
|
||||
pass
|
||||
else:
|
||||
break
|
||||
yield client
|
||||
del client
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, shared_client, request):
|
||||
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
# Happy Path Tests
|
||||
|
||||
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that a basic async node executes correctly."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
|
||||
output = g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution completed
|
||||
assert result.did_run(sleep_node), "Async sleep node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify the image passed through correctly
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async sleep nodes with different durations
|
||||
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
|
||||
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Add outputs for each
|
||||
_output1 = g.node("PreviewImage", images=sleep1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
        elapsed_time = time.time() - start_time

        # Should take ~0.5s (max duration) not 1.2s (sum of durations)
        assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"

        # Verify all nodes executed
        assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with proper dependency handling."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Chain of async operations
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Average depends on both async results
        average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)

        # Verify execution order
        assert result.did_run(sleep1) and result.did_run(sleep2)
        assert result.did_run(average) and result.did_run(output)

        # Verify averaged result
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
        """Test async VALIDATE_INPUTS function."""
        g = builder
        # Create a test node with async validation
        validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        # Should pass validation
        result = client.run(g)
        assert result.did_run(validation_node)

        # Test validation failure
        validation_node.inputs['threshold'] = 3.0  # Will fail since value > threshold
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with lazy evaluation."""
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        # Create async nodes that will be evaluated lazily
        sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)

        # Use lazy mix that only needs sleep1 (mask=0.0)
        lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should only execute sleep1, not sleep2
        assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
        assert result.did_run(sleep1), "Sleep1 should have executed"
        assert not result.did_run(sleep2), "Sleep2 should have been skipped"

    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
        """Test async check_lazy_status function."""
        g = builder
        # Create a node with async check_lazy_status
        lazy_node = g.node("TestAsyncLazyCheck",
                           input1="value1",
                           input2="value2",
                           condition=True)
        g.node("SaveImage", images=lazy_node.out(0))

        result = client.run(g)
        assert result.did_run(lazy_node)

    # Error Handling Tests

    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async execution errors are properly handled."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Create an async node that will error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
            assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"

    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test async validation error handling."""
        g = builder
        # Node with async validation that will fail
        validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            client.run(g)
        # Verify it's a validation error
        assert exc_info.value.code == 400

    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling of async operations that timeout."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Very long sleep that would timeout
        timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
        g.node("SaveImage", images=timeout_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised a timeout error"
        except Exception as e:
            assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"

    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
        """Test that workflow can recover after async errors."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # First run with error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
        except Exception:
            pass  # Expected

        # Second run should succeed
        g2 = GraphBuilder(prefix="recovery_test")
        image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
        g2.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g2)
        assert result.did_run(sleep_node), "Should be able to run after error"

    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling when sync node errors while async node is executing."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Async node that takes time
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Sync node that will error immediately
        error_node = g.node("TestSyncError", value=image.out(0))

        # Both feed into output
        g.node("PreviewImage", images=sleep_node.out(0))
        g.node("PreviewImage", images=error_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            # Verify the sync error was caught even though async was running
            assert 'prompt_id' in e.args[0]

    # Edge Cases

    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with execution blockers."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Async sleep nodes
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Create list of images
        image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))

        # Create list of blocking conditions - [False, True] to block only the second item
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))

        # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
        compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")

        # Block based on the comparison results
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        output = g.node("PreviewImage", images=blocker.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have blocked second image"

    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async nodes are properly cached."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
        g.node("SaveImage", images=sleep_node.out(0))

        # First run
        result1 = client.run(g)
        assert result1.did_run(sleep_node), "Should run first time"

        # Second run - should be cached
        start_time = time.time()
        result2 = client.run(g)
        elapsed_time = time.time() - start_time

        assert not result2.did_run(sleep_node), "Should be cached"
        assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes within dynamically generated prompts."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Node that generates async nodes dynamically
        dynamic_async = g.node("TestDynamicAsyncGeneration",
                               image1=image1.out(0),
                               image2=image2.out(0),
                               num_async_nodes=3,
                               sleep_duration=0.2)
        g.node("SaveImage", images=dynamic_async.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should execute async nodes in parallel within dynamic prompt
        assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
        assert result.did_run(dynamic_async)

    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async resources are properly cleaned up."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async nodes that use resources
        resource_nodes = []
        for i in range(5):
            node = g.node("TestAsyncResourceUser",
                          value=image.out(0),
                          resource_id=f"resource_{i}",
                          duration=0.1)
            resource_nodes.append(node)
            g.node("PreviewImage", images=node.out(0))

        result = client.run(g)

        # Verify all nodes executed
        for node in resource_nodes:
            assert result.did_run(node)

        # Run again to ensure resources were cleaned up
        result2 = client.run(g)
        # Should be cached but not error due to resource conflicts
        for node in resource_nodes:
            assert not result2.did_run(node), "Should be cached"

    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
        """Test cancellation of async operations."""
        # This would require implementing cancellation in the client
        # For now, we'll test that long-running async operations can be interrupted
        pass  # TODO: Implement when cancellation API is available

    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test workflows with both sync and async nodes."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        # Mix of sync and async operations
        # Sync: lazy mix images
        sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
        # Async: sleep
        async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
        # Sync: custom validation
        sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
        # Async: sleep again
        async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)

        output = g.node("SaveImage", images=async_op2.out(0))

        result = client.run(g)

        # Verify all nodes executed in correct order
        assert result.did_run(sync_op1)
        assert result.did_run(async_op1)
        assert result.did_run(sync_op2)
        assert result.did_run(async_op2)

        # Image should be a mix of black and white (gray)
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
@@ -252,7 +252,7 @@ class TestExecution:

    @pytest.mark.parametrize("test_type, test_value", [
        ("StubInt", 5),
        ("StubFloat", 5.0)
        ("StubMask", 5.0)
    ])
    def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
        g = builder
@@ -497,6 +497,69 @@ class TestExecution:
        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
        assert not result.did_run(test_node), "The execution should have been cached"

    def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create sleep nodes for each duration
        sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
        sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
        sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)

        # Add outputs to verify the execution
        _output1 = g.node("PreviewImage", images=sleep_node1.out(0))
        _output2 = g.node("PreviewImage", images=sleep_node2.out(0))
        _output3 = g.node("PreviewImage", images=sleep_node3.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # The test should take around 3.0 seconds (the longest sleep duration)
        # plus some overhead, but definitely less than the sum of all sleeps (8.7s)
        # We'll allow for up to 4.0s total to account for overhead
        assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 4.0s"

        # Verify that all nodes executed
        assert result.did_run(sleep_node1), "Sleep node 1 should have run"
        assert result.did_run(sleep_node2), "Sleep node 2 should have run"
        assert result.did_run(sleep_node3), "Sleep node 3 should have run"

    def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        # Create input images with different values
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Create a TestParallelSleep node that expands into multiple TestSleep nodes
        parallel_sleep = g.node("TestParallelSleep",
                                image1=image1.out(0),
                                image2=image2.out(0),
                                image3=image3.out(0),
                                sleep1=0.4,
                                sleep2=0.5,
                                sleep3=0.6)
        output = g.node("SaveImage", images=parallel_sleep.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Similar to the previous test, expect parallel execution of the sleep nodes,
        # which should complete in less than the sum of all sleeps (1.5s)
        assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"

        # Verify the parallel sleep node executed
        assert result.did_run(parallel_sleep), "ParallelSleep node should have run"

        # Verify we get an image as output (blend of the three input images)
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        # Average pixel value should be around 170 (255 * 2 // 3)
        avg_value = numpy.array(result_images[0]).mean()
        assert avg_value == 170, f"Image average value {avg_value} should be 170"

    # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
    # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
    # only that one entry in the list is blocked.
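For readers unfamiliar with the mechanism that comment refers to, the short sketch below is an illustrative aside and not part of this diff. It assumes ExecutionBlocker is importable from comfy_execution.graph, as in the main codebase, and shows a hypothetical list-output node that blocks only one entry of its output list.

# Illustrative aside (hypothetical node, not part of this diff): block only the
# second entry of a list output so downstream nodes skip just that item.
from comfy_execution.graph import ExecutionBlocker  # import path assumed

class ExampleBlockSecondEntry:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    OUTPUT_IS_LIST = (True,)
    FUNCTION = "run"
    CATEGORY = "_for_testing"

    def run(self, image):
        # First entry flows through normally; the second is replaced with an
        # ExecutionBlocker(None), so consumers of this list run only for item 1.
        return ([image, ExecutionBlocker(None)],)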
@@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS

# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
@@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)

NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
@@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)

343  tests/inference/testing_nodes/testing-pack/async_test_nodes.py  Normal file
@@ -0,0 +1,343 @@
import torch
import asyncio
from typing import Dict
from comfy.utils import ProgressBar
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO


class TestAsyncValidation(ComfyNodeABC):
    """Test node with async VALIDATE_INPUTS."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "threshold": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, threshold):
        # Simulate async validation (e.g., checking remote service)
        await asyncio.sleep(0.05)

        if value > threshold:
            return f"Value {value} exceeds threshold {threshold}"
        return True

    def process(self, value, threshold):
        # Create image based on value
        intensity = value / 10.0
        image = torch.ones([1, 512, 512, 3]) * intensity
        return (image,)


class TestAsyncError(ComfyNodeABC):
    """Test node that errors during async execution."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "error_execution"
    CATEGORY = "_for_testing/async"

    async def error_execution(self, value, error_after):
        await asyncio.sleep(error_after)
        raise RuntimeError("Intentional async execution error for testing")


class TestAsyncValidationError(ComfyNodeABC):
    """Test node with async validation that always fails."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "max_value": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, max_value):
        await asyncio.sleep(0.05)
        # Always fail validation for values > max_value
        if value > max_value:
            return f"Async validation failed: {value} > {max_value}"
        return True

    def process(self, value, max_value):
        # This won't be reached if validation fails
        image = torch.ones([1, 512, 512, 3]) * (value / max_value)
        return (image,)


class TestAsyncTimeout(ComfyNodeABC):
    """Test node that simulates timeout scenarios."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
                "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "timeout_execution"
    CATEGORY = "_for_testing/async"

    async def timeout_execution(self, value, timeout, operation_time):
        try:
            # This will timeout if operation_time > timeout
            await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
            return (value,)
        except asyncio.TimeoutError:
            raise RuntimeError(f"Operation timed out after {timeout} seconds")


class TestSyncError(ComfyNodeABC):
    """Test node that errors synchronously (for mixed sync/async testing)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sync_error"
    CATEGORY = "_for_testing/async"

    def sync_error(self, value):
        raise RuntimeError("Intentional sync execution error for testing")


class TestAsyncLazyCheck(ComfyNodeABC):
    """Test node with async check_lazy_status."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": (IO.ANY, {"lazy": True}),
                "input2": (IO.ANY, {"lazy": True}),
                "condition": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    async def check_lazy_status(self, condition, input1, input2):
        # Simulate async checking (e.g., querying remote service)
        await asyncio.sleep(0.05)

        needed = []
        if condition and input1 is None:
            needed.append("input1")
        if not condition and input2 is None:
            needed.append("input2")
        return needed

    def process(self, input1, input2, condition):
        # Return a simple image
        return (torch.ones([1, 512, 512, 3]),)


class TestDynamicAsyncGeneration(ComfyNodeABC):
    """Test node that dynamically generates async nodes."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
                "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_async_workflow"
    CATEGORY = "_for_testing/async"

    def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
        g = GraphBuilder()

        # Create multiple async sleep nodes
        sleep_nodes = []
        for i in range(num_async_nodes):
            image = image1 if i % 2 == 0 else image2
            sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
            sleep_nodes.append(sleep_node)

        # Average all results
        if len(sleep_nodes) == 1:
            final_node = sleep_nodes[0]
        else:
            avg_inputs = {"input1": sleep_nodes[0].out(0)}
            for i, node in enumerate(sleep_nodes[1:], 2):
                avg_inputs[f"input{i}"] = node.out(0)
            final_node = g.node("TestVariadicAverage", **avg_inputs)

        return {
            "result": (final_node.out(0),),
            "expand": g.finalize(),
        }


class TestAsyncResourceUser(ComfyNodeABC):
    """Test node that uses resources during async execution."""

    # Class-level resource tracking for testing
    _active_resources: Dict[str, bool] = {}

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "resource_id": ("STRING", {"default": "resource_0"}),
                "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "use_resource"
    CATEGORY = "_for_testing/async"

    async def use_resource(self, value, resource_id, duration):
        # Check if resource is already in use
        if self._active_resources.get(resource_id, False):
            raise RuntimeError(f"Resource {resource_id} is already in use!")

        # Mark resource as in use
        self._active_resources[resource_id] = True

        try:
            # Simulate resource usage
            await asyncio.sleep(duration)
            return (value,)
        finally:
            # Always clean up resource
            self._active_resources[resource_id] = False


class TestAsyncBatchProcessing(ComfyNodeABC):
    """Test async processing of batched inputs."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process_batch"
    CATEGORY = "_for_testing/async"

    async def process_batch(self, images, process_time_per_item, unique_id):
        batch_size = images.shape[0]
        pbar = ProgressBar(batch_size, node_id=unique_id)

        # Process each image in the batch
        processed = []
        for i in range(batch_size):
            # Simulate async processing
            await asyncio.sleep(process_time_per_item)

            # Simple processing: invert the image
            processed_image = 1.0 - images[i:i+1]
            processed.append(processed_image)

            pbar.update(1)

        # Stack processed images
        result = torch.cat(processed, dim=0)
        return (result,)


class TestAsyncConcurrentLimit(ComfyNodeABC):
    """Test concurrent execution limits for async nodes."""

    _semaphore = asyncio.Semaphore(2)  # Only allow 2 concurrent executions

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
                "node_id": ("INT", {"default": 0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "limited_execution"
    CATEGORY = "_for_testing/async"

    async def limited_execution(self, value, duration, node_id):
        async with self._semaphore:
            # Node {node_id} acquired semaphore
            await asyncio.sleep(duration)
            # Node {node_id} releasing semaphore
            return (value,)


# Add node mappings
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
    "TestAsyncValidation": TestAsyncValidation,
    "TestAsyncError": TestAsyncError,
    "TestAsyncValidationError": TestAsyncValidationError,
    "TestAsyncTimeout": TestAsyncTimeout,
    "TestSyncError": TestSyncError,
    "TestAsyncLazyCheck": TestAsyncLazyCheck,
    "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
    "TestAsyncResourceUser": TestAsyncResourceUser,
    "TestAsyncBatchProcessing": TestAsyncBatchProcessing,
    "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
}

ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestAsyncValidation": "Test Async Validation",
    "TestAsyncError": "Test Async Error",
    "TestAsyncValidationError": "Test Async Validation Error",
    "TestAsyncTimeout": "Test Async Timeout",
    "TestSyncError": "Test Sync Error",
    "TestAsyncLazyCheck": "Test Async Lazy Check",
    "TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
    "TestAsyncResourceUser": "Test Async Resource User",
    "TestAsyncBatchProcessing": "Test Async Batch Processing",
    "TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
}
@@ -1,6 +1,11 @@
import torch
import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO

class TestLazyMixImages:
    @classmethod
@@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
            "expand": g.finalize(),
        }

class TestSamplingInExpansion:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 100}),
                "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
                "prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
                "negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "sampling_in_expansion"

    CATEGORY = "Testing/Nodes"

    def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
        g = GraphBuilder()

        # Create a basic image generation workflow using the input model, clip and vae
        # 1. Setup text prompts using the provided CLIP model
        positive_prompt = g.node("CLIPTextEncode",
                                 text=prompt,
                                 clip=clip)
        negative_prompt = g.node("CLIPTextEncode",
                                 text=negative_prompt,
                                 clip=clip)

        # 2. Create empty latent with specified size
        empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)

        # 3. Setup sampler and generate image latent
        sampler = g.node("KSampler",
                         model=model,
                         positive=positive_prompt.out(0),
                         negative=negative_prompt.out(0),
                         latent_image=empty_latent.out(0),
                         seed=seed,
                         steps=steps,
                         cfg=cfg,
                         sampler_name="euler_ancestral",
                         scheduler="normal")

        # 4. Decode latent to image using VAE
        output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)

        return {
            "result": (output.out(0),),
            "expand": g.finalize(),
        }

class TestSleep(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }
    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sleep"

    CATEGORY = "_for_testing"

    async def sleep(self, value, seconds, unique_id):
        pbar = ProgressBar(seconds, node_id=unique_id)
        start = time.time()
        expiration = start + seconds
        now = start
        while now < expiration:
            now = time.time()
            pbar.update_absolute(now - start)
            await asyncio.sleep(0.01)
        return (value,)

class TestParallelSleep(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE", ),
                "image2": ("IMAGE", ),
                "image3": ("IMAGE", ),
                "sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
                "sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
                "sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "parallel_sleep"
    CATEGORY = "_for_testing"
    OUTPUT_NODE = True

    def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
        # Create a graph dynamically with three TestSleep nodes
        g = GraphBuilder()

        # Create sleep nodes for each duration and image
        sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
        sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
        sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)

        # Blend the results using TestVariadicAverage
        blend = g.node("TestVariadicAverage",
                       input1=sleep_node1.out(0),
                       input2=sleep_node2.out(0),
                       input3=sleep_node3.out(0))

        return {
            "result": (blend.out(0),),
            "expand": g.finalize(),
        }

TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
@@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
    "TestCustomValidation5": TestCustomValidation5,
    "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
    "TestMixedExpansionReturns": TestMixedExpansionReturns,
    "TestSamplingInExpansion": TestSamplingInExpansion,
    "TestSleep": TestSleep,
    "TestParallelSleep": TestParallelSleep,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestCustomValidation5": "Custom Validation 5",
    "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
    "TestMixedExpansionReturns": "Mixed Expansion Returns",
    "TestSamplingInExpansion": "Sampling In Expansion",
    "TestSleep": "Test Sleep",
    "TestParallelSleep": "Test Parallel Sleep",
}

18  utils/install_util.py  Normal file
@@ -0,0 +1,18 @@
from pathlib import Path
import sys

# The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt"


def get_missing_requirements_message():
    """The warning message to display when a package is missing."""

    extra = ""
    if sys.flags.no_user_site:
        extra = "-s "
    return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path}
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip()
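As an illustrative aside (not part of this diff), a caller could surface this message when a required package fails to import at startup; the exact call site and the package named below are assumptions.

# Hypothetical usage sketch of utils/install_util.py (illustrative only):
try:
    import yaml  # stands in for any package listed in requirements.txt
except ImportError:
    from utils.install_util import get_missing_requirements_message
    # Print the install hint, then re-raise so startup still fails visibly.
    print(get_missing_requirements_message())
    raise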