Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 23:14:49 +08:00)
Compare commits
163 Commits
Commit SHA1s:
b8ffb2937f, ce37c11164, b5c3906b38, 5d43e75e5b, 517f4a94e4, 52a471c5c7, ad76574cb8, 9acfe4df41,
9829b013ea, 5c69cde037, e9589d6d92, 0d82a798a5, 925fff26fd, 75b9b55b22, 1765f1c60c, 1de69fe4d5,
ae197f651b, 1b5b8ca81a, 6678d5cf65, e172564eea, a3cc326748, 86a97e91fc, 5acdadc9f3, 55ad9d5f8c,
a9f04edc58, a475ec2300, 06eb9fb426, 413322645e, 11200de970, 037c38eb0f, 1e11d2d1f5, 65ea6be38f,
5df6f57b5d, 6588bfdef9, 50ed2879ef, 66d4233210, 591010b7ef, 08f92d55e9, 8115d8cce9, 6969fc9ba4,
cb7c4b4be3, 1208863eca, e1c528196e, 17030fd4c0, c19dcd362f, 1c08bf35b4, 2a02546e20, b334605a66,
de17a9755e, c14ac98fed, 2894511893, f3bc40223a, 841e74ac40, 2d75df45e6, 1abc9c8703, 8edbcf5209,
e545a636ba, 33e5203a2a, a178e25912, 78e133d041, 7afa985fba, ddb6a9f47c, 3b71f84b50, 0a6b008117,
56f3c660bf, f7a5107784, 91be9c2867, 03c5018c98, 2ba5cc8b86, 1e68002b87, ba9095e5bd, f123328b82,
63a7e8edba, 0eea47d580, 7cd0cdfce6, ea03c9dcd2, 3a9ee995cf, 47da42d928, 17bbd83176, bfb52de866,
eca962c6da, c1696cd1b5, 369f459b20, ce9ac2fe05, e638f2858a, a531001cc7, d420bc792a, d965474aaa,
1c61361fd2, a6decf1e62, 48eb1399c0, b4f6ebb2e8, d7430a1651, f2b80f95d2, 1aa9cf3292, 2f88d19ef3,
eb96c3bd82, 5f98de7697, 8d34211a7a, 1589b58d3e, 7ad574bffd, e2382b6adb, c24f897352, a5991a7aa6,
2c038ccef0, b85216a3c0, 82cae45d44, 25853d0be8, 79040635da, 66d35c07ce, c75b50607b, 4ba7fa0244,
ab76abc767, 9300058026, f82d09c9b4, e6829e7ac5, 07f6a1a685, e746965c50, 45a2842d7f, 17b41f622e,
cf4418b806, 6225a7827c, b6779d8df3, 8328a2d8cd, afe732bef9, a9ac56fc0d, 25b51b1a8b, 61a2b00bc2,
a5f4292f9f, f87810cd3e, 10c919f4c7, ce80e69fb8, 19944ad252, 10b43ceea5, 0a4c49c57c, 88ed893034,
334ba48cea, 14764aa2e2, b2c995f623, 4151fbfa8a, 6045ed31f8, f836e69346, 5b69cfe7c3, 95fa9545f1,
11b74147ee, 6ab8cad22e, 011b11d8d7, ff6ca2a892, 374e093e09, 855789403b, 6f7869f365, 281ad42df4,
1cde6b2eff, c5a48b15bd, f2298799ba, 60383f3b64, 8270c62530, 821f93872e, e1630391d6, 99458e8aca,
33346fd9b8, 136c93cb47, 1305fb294c
@@ -62,8 +62,15 @@ except:
 print("checking out master branch")
 branch = repo.lookup_branch('master')
-ref = repo.lookup_reference(branch.name)
-repo.checkout(ref)
+if branch is None:
+    ref = repo.lookup_reference('refs/remotes/origin/master')
+    repo.checkout(ref)
+    branch = repo.lookup_branch('master')
+    if branch is None:
+        repo.create_branch('master', repo.get(ref.target))
+else:
+    ref = repo.lookup_reference(branch.name)
+    repo.checkout(ref)

 print("pulling latest changes")
 pull(repo)
53 .github/workflows/pullrequest-ci-run.yml (vendored, new file)
@@ -0,0 +1,53 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
  pull_request_target:
    types: [labeled]

jobs:
  pr-test-stable:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
          use_prior_commit: 'true'
  comment:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    runs-on: ubuntu-latest
    permissions:
      pull-requests: write
    steps:
      - uses: actions/github-script@v6
        with:
          script: |
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
            })
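The pr-test-stable job above only fires when the 'Run-CI-Test' label is applied to a pull request (pull_request_target with the labeled event). As a hedged illustration, not part of the workflow itself, the label can also be applied through the GitHub REST API; the token, owner, repo and PR number below are placeholders supplied by the caller.

```python
# Hedged sketch: apply the 'Run-CI-Test' label to a PR so the "labeled" trigger
# above matches and the GPU test matrix runs. GITHUB_TOKEN and the PR number are
# assumptions provided by whoever runs this.
import os
import requests

def trigger_ci(owner: str, repo: str, pr_number: int) -> None:
    url = f"https://api.github.com/repos/{owner}/{repo}/issues/{pr_number}/labels"
    headers = {
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    }
    # Adding the label is what makes github.event.label.name equal 'Run-CI-Test'.
    resp = requests.post(url, headers=headers, json={"labels": ["Run-CI-Test"]}, timeout=10)
    resp.raise_for_status()

# Example: trigger_ci("comfyanonymous", "ComfyUI", 1234)
```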
23 .github/workflows/pylint.yml (vendored, new file)
@@ -0,0 +1,23 @@
name: Python Linting

on: [push, pull_request]

jobs:
  pylint:
    name: Run Pylint
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.x

      - name: Install Pylint
        run: pip install pylint

      - name: Run Pylint
        run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
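To reproduce the lint step locally, a hedged sketch using pylint's Python API is shown below; it assumes pylint is installed and that the repository's .pylintrc sits in the working directory, mirroring the find/pylint command in the last step above.

```python
# Hedged local equivalent of the CI lint step; .pylintrc at the repo root is an
# assumption of this sketch.
import pathlib
from pylint.lint import Run

py_files = [str(p) for p in pathlib.Path(".").rglob("*.py")]
# exit=False keeps the interpreter alive so the result object can be inspected.
results = Run(["--rcfile=.pylintrc", *py_files], exit=False)
print("pylint exit status:", results.linter.msg_status)
```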
95
.github/workflows/stable-release.yml
vendored
95
.github/workflows/stable-release.yml
vendored
@@ -2,9 +2,28 @@
|
||||
name: "Release Stable Version"
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cu:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "11"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
|
||||
|
||||
jobs:
|
||||
package_comfy_windows:
|
||||
@@ -13,69 +32,44 @@ jobs:
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python_version: [3.11.8]
|
||||
cuda_version: [121]
|
||||
steps:
|
||||
- name: Calculate Minor Version
|
||||
shell: bash
|
||||
run: |
|
||||
# Extract the minor version from the Python version
|
||||
MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
|
||||
echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.git_tag }}
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- uses: actions/cache/restore@v4
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
cu${{ inputs.cu }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
|
||||
mv cu${{ matrix.cuda_version }}_python_deps ../
|
||||
mv cu${{ inputs.cu }}_python_deps.tar ../
|
||||
mv update_comfyui_and_python_dependencies.bat ../
|
||||
cd ..
|
||||
tar xf cu${{ inputs.cu }}_python_deps.tar
|
||||
pwd
|
||||
ls
|
||||
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
cd ..
|
||||
cp -r ComfyUI ComfyUI_copy
|
||||
curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
|
||||
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
|
||||
unzip python_embeded.zip -d python_embeded
|
||||
cd python_embeded
|
||||
echo ${{ env.MINOR_VERSION }}
|
||||
echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe --version
|
||||
echo "Pip version:"
|
||||
./python.exe -m pip --version
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
|
||||
set PATH=$PWD/Scripts:$PATH
|
||||
echo $PATH
|
||||
./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/comfyanonymous/taesd
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
@@ -104,6 +98,7 @@ jobs:
|
||||
with:
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ github.ref }}
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
|
||||
prerelease: true
|
||||
make_latest: false
|
||||
|
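The packaging steps above patch the embedded interpreter's ._pth file (the `echo 'import site' >>` line and the `sed -i '1i../ComfyUI'` line). A hedged Python sketch of the net effect follows; the python311._pth filename is an assumption matching python_minor=11.

```python
# Hedged sketch of what the workflow's echo/sed lines do to the embedded Python's
# path file: '../ComfyUI' is prepended so the portable interpreter can import
# ComfyUI, and 'import site' is appended so pip-installed packages become visible.
from pathlib import Path

pth = Path("python_embeded/python311._pth")  # assumed name for Python 3.11
lines = pth.read_text().splitlines()
patched = ["../ComfyUI"] + lines + ["import site"]
pth.write_text("\n".join(patched) + "\n")
```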
2 .github/workflows/test-browser.yml (vendored)
@@ -32,7 +32,7 @@ jobs:
       node-version: lts/*
   - uses: actions/setup-python@v4
     with:
-      python-version: '3.10'
+      python-version: '3.8'
   - name: Install requirements
     run: |
       python -m pip install --upgrade pip
95
.github/workflows/test-ci.yml
vendored
Normal file
95
.github/workflows/test-ci.yml
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
|
||||
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
|
||||
name: Full Comfy CI Workflow Runs
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
- 'output/**'
|
||||
- 'notebooks/**'
|
||||
- 'script_examples/**'
|
||||
- '.github/**'
|
||||
- 'web/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test-stable:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux, windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-win-nightly:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-unix-nightly:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
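A hedged sketch of how the test-stable matrix above expands: GitHub merges each include entry into the combinations whose os key matches, so 3 operating systems times 4 Python versions yield 12 jobs, each carrying its runner_label and flags.

```python
# Hedged model of GitHub Actions matrix expansion for the test-stable job above.
from itertools import product

matrix = {
    "os": ["macos", "linux", "windows"],
    "python_version": ["3.9", "3.10", "3.11", "3.12"],
    "cuda_version": ["12.1"],
    "torch_version": ["stable"],
}
include = {
    "macos": {"runner_label": ["self-hosted", "macOS"], "flags": "--use-pytorch-cross-attention"},
    "linux": {"runner_label": ["self-hosted", "Linux"], "flags": ""},
    "windows": {"runner_label": ["self-hosted", "win"], "flags": ""},
}

jobs = []
for os_, py, cu, torch_v in product(*matrix.values()):
    job = {"os": os_, "python_version": py, "cuda_version": cu, "torch_version": torch_v}
    job.update(include[os_])  # matching include entries extend the combination
    jobs.append(job)

print(len(jobs))  # 12
```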
4 .github/workflows/test-ui.yaml (vendored)
@@ -24,3 +24,7 @@ jobs:
         npm run test:generate
         npm test -- --verbose
       working-directory: ./tests-ui
+    - name: Run Unit Tests
+      run: |
+        pip install -r tests-unit/requirements.txt
+        python -m pytest tests-unit
@@ -8,11 +8,16 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
extra_dependencies:
|
||||
description: 'extra dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "\"numpy<2\""
|
||||
cu:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -24,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "8"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -51,7 +56,7 @@ jobs:
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
|
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "3"
|
||||
default: "4"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -49,13 +49,13 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/comfyanonymous/taesd
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable_nightly_pytorch
|
||||
|
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "8"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/comfyanonymous/taesd
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
|
4 .gitignore (vendored)
@@ -17,4 +17,6 @@ venv/
 !/web/extensions/core/
 /tests-ui/data/object_info.json
 /user/
 *.log
+web_custom_versions/
+.DS_Store
18 README.md
@@ -12,6 +12,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -33,6 +34,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
 - [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
 - [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
+- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
@@ -77,7 +79,7 @@ Ctrl can also be replaced with Cmd instead for macOS users

 There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).

-### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
+### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)

 Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints

@@ -163,20 +165,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve

 ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```

-### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
-
-You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
-
-```source path_to_other_sd_gui/venv/bin/activate```
-
-or on Windows:
-
-With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
-
-With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
-
-And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
-
 # Running

 ```python main.py```
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
188
app/frontend_management.py
Normal file
188
app/frontend_management.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class Asset(TypedDict):
|
||||
url: str
|
||||
|
||||
|
||||
class Release(TypedDict):
|
||||
id: int
|
||||
tag_name: str
|
||||
name: str
|
||||
prerelease: bool
|
||||
created_at: str
|
||||
published_at: str
|
||||
body: str
|
||||
assets: NotRequired[list[Asset]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontEndProvider:
|
||||
owner: str
|
||||
repo: str
|
||||
|
||||
@property
|
||||
def folder_name(self) -> str:
|
||||
return f"{self.owner}_{self.repo}"
|
||||
|
||||
@property
|
||||
def release_url(self) -> str:
|
||||
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
|
||||
|
||||
@cached_property
|
||||
def all_releases(self) -> list[Release]:
|
||||
releases = []
|
||||
api_url = self.release_url
|
||||
while api_url:
|
||||
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
releases.extend(response.json())
|
||||
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
|
||||
if "next" in response.links:
|
||||
api_url = response.links["next"]["url"]
|
||||
else:
|
||||
api_url = None
|
||||
return releases
|
||||
|
||||
@cached_property
|
||||
def latest_release(self) -> Release:
|
||||
latest_release_url = f"{self.release_url}/latest"
|
||||
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
return response.json()
|
||||
|
||||
def get_release(self, version: str) -> Release:
|
||||
if version == "latest":
|
||||
return self.latest_release
|
||||
else:
|
||||
for release in self.all_releases:
|
||||
if release["tag_name"] in [version, f"v{version}"]:
|
||||
return release
|
||||
raise ValueError(f"Version {version} not found in releases")
|
||||
|
||||
|
||||
def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
"""Download dist.zip from github release."""
|
||||
asset_url = None
|
||||
for asset in release.get("assets", []):
|
||||
if asset["name"] == "dist.zip":
|
||||
asset_url = asset["url"]
|
||||
break
|
||||
|
||||
if not asset_url:
|
||||
raise ValueError("dist.zip not found in the release assets")
|
||||
|
||||
# Use a temporary file to download the zip content
|
||||
with tempfile.TemporaryFile() as tmp_file:
|
||||
headers = {"Accept": "application/octet-stream"}
|
||||
response = requests.get(
|
||||
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
|
||||
)
|
||||
response.raise_for_status() # Ensure we got a successful response
|
||||
|
||||
# Write the content to the temporary file
|
||||
tmp_file.write(response.content)
|
||||
|
||||
# Go back to the beginning of the temporary file
|
||||
tmp_file.seek(0)
|
||||
|
||||
# Extract the zip file content to the destination path
|
||||
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_path)
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
"""
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
|
||||
match_result = re.match(VERSION_PATTERN, value)
|
||||
if match_result is None:
|
||||
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
||||
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
web_root = str(
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
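A hedged usage sketch of the new FrontendManager follows; the Comfy-Org/ComfyUI_frontend repository is an assumed example rather than a value taken from this diff, and any repository publishing a dist.zip release asset would fit the version-string format.

```python
# Hedged usage sketch of app/frontend_management.py above.
from app.frontend_management import FrontendManager

# parse_version_string splits "[repoOwner]/[repoName]@[version]" into its parts;
# "Comfy-Org/ComfyUI_frontend" is an assumed example owner/repo.
owner, repo, version = FrontendManager.parse_version_string("Comfy-Org/ComfyUI_frontend@latest")

# init_frontend downloads and extracts dist.zip under web_custom_versions/ and
# returns the web root; on any failure it logs and falls back to the bundled web/.
web_root = FrontendManager.init_frontend("Comfy-Org/ComfyUI_frontend@latest")
print(web_root)
```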
@@ -13,6 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
|
||||
from ..ldm.modules.attention import SpatialTransformer
|
||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
||||
from ..ldm.util import exists
|
||||
from .control_types import UNION_CONTROLNET_TYPES
|
||||
from collections import OrderedDict
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
@@ -361,7 +362,7 @@ class ControlNet(nn.Module):
|
||||
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
|
||||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
||||
if idx < len(control_type):
|
||||
feat_seq += self.task_embedding[control_type[idx]]
|
||||
feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
|
||||
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(controlnet_cond)
|
||||
@@ -390,6 +391,18 @@ class ControlNet(nn.Module):
|
||||
if self.control_add_embedding is not None: #Union Controlnet
|
||||
control_type = kwargs.get("control_type", [])
|
||||
|
||||
if any([c >= self.num_control_type for c in control_type]):
|
||||
max_type = max(control_type)
|
||||
max_type_name = {
|
||||
v: k for k, v in UNION_CONTROLNET_TYPES.items()
|
||||
}[max_type]
|
||||
raise ValueError(
|
||||
f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
|
||||
f"({self.num_control_type}) supported.\n" +
|
||||
"Please consider using the ProMax ControlNet Union model.\n" +
|
||||
"https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
|
||||
)
|
||||
|
||||
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
|
||||
if len(control_type) > 0:
|
||||
if len(hint.shape) < 5:
|
||||
|
10 comfy/cldm/control_types.py (new file)
@@ -0,0 +1,10 @@
UNION_CONTROLNET_TYPES = {
    "openpose": 0,
    "depth": 1,
    "hed/pidi/scribble/ted": 2,
    "canny/lineart/anime_lineart/mlsd": 3,
    "normal": 4,
    "segment": 5,
    "tile": 6,
    "repaint": 7,
}
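A hedged sketch of how this table is consumed: a control-type name selects the index passed around as control_type, and the reverse mapping is what the new out-of-range error in cldm.py above uses to report a readable name.

```python
# Hedged sketch of the UNION_CONTROLNET_TYPES lookups used above.
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

control_type = [UNION_CONTROLNET_TYPES["depth"]]  # -> [1]

# Reverse map, as used when raising the "out of range" ValueError in cldm.py.
name_by_index = {v: k for k, v in UNION_CONTROLNET_TYPES.items()}
print(name_by_index[max(control_type)])  # -> "depth"
```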
@@ -1,7 +1,10 @@
|
||||
import argparse
|
||||
import enum
|
||||
import os
|
||||
from typing import Optional
|
||||
import comfy.options
|
||||
|
||||
|
||||
class EnumAction(argparse.Action):
|
||||
"""
|
||||
Argparse action for handling Enums
|
||||
@@ -109,6 +112,7 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
@@ -124,6 +128,38 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
||||
|
||||
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-version",
|
||||
type=str,
|
||||
default=DEFAULT_VERSION_STRING,
|
||||
help="""
|
||||
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
|
||||
download available frontend implementations from GitHub releases.
|
||||
|
||||
The version string should be in the format of:
|
||||
[repoOwner]/[repoName]@[version]
|
||||
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
|
||||
""",
|
||||
)
|
||||
|
||||
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||
"""Validate if the given path is a directory."""
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
if not os.path.isdir(path):
|
||||
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
|
||||
return path
|
||||
|
||||
parser.add_argument(
|
||||
"--front-end-root",
|
||||
type=is_valid_directory,
|
||||
default=None,
|
||||
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
|
||||
)
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
@@ -5,7 +5,7 @@
   "attention_dropout": 0.0,
   "bos_token_id": 0,
   "dropout": 0.0,
-  "eos_token_id": 2,
+  "eos_token_id": 49407,
   "hidden_act": "gelu",
   "hidden_size": 1280,
   "initializer_factor": 1.0,
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
@@ -71,13 +72,13 @@ class CLIPEncoder(torch.nn.Module):
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
|
||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
def forward(self, input_tokens, dtype=torch.float32):
|
||||
return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
@@ -87,14 +88,15 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
self.eos_token_id = config_dict["eos_token_id"]
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
x = self.embeddings(input_tokens)
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
@@ -111,7 +113,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
@@ -153,11 +155,11 @@ class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
|
||||
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
@@ -169,7 +171,7 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
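A hedged illustration of the pooled-output change above: the old code took the position of the numerically largest token id, while the new code searches for the configured eos_token_id (rounding first, so float token tensors also work). The toy token sequence below is an assumption.

```python
# Hedged toy comparison of the old and new pooled-output index selection.
import torch

eos_token_id = 49407
tokens = torch.tensor([[49406.0, 320.0, 1125.0, 49407.0, 0.0, 0.0]])  # assumed: BOS, words, EOS, padding

old_pos = tokens.to(torch.int).argmax(dim=-1)                                    # index of max id
new_pos = (torch.round(tokens).to(torch.int) == eos_token_id).int().argmax(dim=-1)  # first EOS position
print(old_pos.item(), new_pos.item())  # 3 3 -- identical only when EOS is the largest id present
```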
|
@@ -34,6 +34,7 @@ class ClipVisionModel():
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -50,7 +51,7 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -93,7 +94,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
18 comfy/clip_vision_config_vitl_336.json (new file)
@@ -0,0 +1,18 @@
{
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "quick_gelu",
  "hidden_size": 1024,
  "image_size": 336,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-5,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 14,
  "projection_dim": 768,
  "torch_dtype": "float32"
}
@@ -1,4 +1,24 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
@@ -13,7 +33,7 @@ import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
|
||||
import comfy.ldm.hydit.controlnet
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
@@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
else:
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
self.cond_hint_original = None
|
||||
@@ -45,11 +69,14 @@ class ControlBase:
|
||||
self.timestep_range = None
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
self.cond_hint_original = cond_hint
|
||||
@@ -90,7 +117,10 @@ class ControlBase:
|
||||
c.compression_ratio = self.compression_ratio
|
||||
c.upscale_algorithm = self.upscale_algorithm
|
||||
c.latent_format = self.latent_format
|
||||
c.extra_args = self.extra_args.copy()
|
||||
c.vae = self.vae
|
||||
c.extra_conds = self.extra_conds.copy()
|
||||
c.strength_type = self.strength_type
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@@ -111,7 +141,10 @@ class ControlBase:
|
||||
|
||||
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
||||
applied_to.add(x)
|
||||
x *= self.strength
|
||||
if self.strength_type == StrengthType.CONSTANT:
|
||||
x *= self.strength
|
||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||
x *= (self.strength ** float(len(control_output) - i))
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
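A hedged numeric illustration of the two StrengthType branches above, assuming a strength of 0.8 and four controlnet outputs: CONSTANT applies the same factor everywhere, while LINEAR_UP raises it to a higher power for earlier (lower-index) outputs.

```python
# Hedged arithmetic sketch of the CONSTANT vs LINEAR_UP scaling above.
strength = 0.8
n_outputs = 4  # assumed number of control outputs

constant = [strength] * n_outputs
linear_up = [strength ** float(n_outputs - i) for i in range(n_outputs)]
print(constant)   # [0.8, 0.8, 0.8, 0.8]
print(linear_up)  # [0.4096, 0.512, 0.64, 0.8]
```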
@@ -135,8 +168,12 @@ class ControlBase:
|
||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||
return out
|
||||
|
||||
def set_extra_arg(self, argument, value=None):
|
||||
self.extra_args[argument] = value
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
@@ -148,6 +185,8 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
self.latent_format = latent_format
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@@ -185,13 +224,16 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(dtype)
|
||||
extra = self.extra_args.copy()
|
||||
for c in self.extra_conds:
|
||||
temp = cond.get(c, None)
|
||||
if temp is not None:
|
||||
extra[c] = temp.to(dtype)
|
||||
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@@ -280,6 +322,7 @@ class ControlLora(ControlNet):
|
||||
ControlBase.__init__(self, device)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.extra_conds += ["y"]
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
@@ -332,12 +375,8 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
def controlnet_config(sd):
|
||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
|
||||
@@ -350,23 +389,49 @@ def load_controlnet_mmdit(sd):
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
||||
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
|
||||
latent_format = comfy.latent_formats.SDXL()
|
||||
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
return load_controlnet_hunyuandit(controlnet_data)
|
||||
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
|
@@ -139,3 +139,32 @@

 class StableAudio1(LatentFormat):
     latent_channels = 64
+
+class Flux(SD3):
+    def __init__(self):
+        self.scale_factor = 0.3611
+        self.shift_factor = 0.1159
+        self.latent_rgb_factors = [
+            [-0.0404,  0.0159,  0.0609],
+            [ 0.0043,  0.0298,  0.0850],
+            [ 0.0328, -0.0749, -0.0503],
+            [-0.0245,  0.0085,  0.0549],
+            [ 0.0966,  0.0894,  0.0530],
+            [ 0.0035,  0.0399,  0.0123],
+            [ 0.0583,  0.1184,  0.1262],
+            [-0.0191, -0.0206, -0.0306],
+            [-0.0324,  0.0055,  0.1001],
+            [ 0.0955,  0.0659, -0.0545],
+            [-0.0504,  0.0231, -0.0013],
+            [ 0.0500, -0.0008, -0.0088],
+            [ 0.0982,  0.0941,  0.0976],
+            [-0.1233, -0.0280, -0.0897],
+            [-0.0005, -0.0530, -0.0020],
+            [-0.1273, -0.0932, -0.0680]
+        ]
+
+    def process_in(self, latent):
+        return (latent - self.shift_factor) * self.scale_factor
+
+    def process_out(self, latent):
+        return (latent / self.scale_factor) + self.shift_factor
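A hedged round-trip check of the new Flux latent format, assuming these classes live in comfy/latent_formats.py alongside the other latent formats: process_in and process_out invert each other.

```python
# Hedged sanity check: process_out undoes process_in (up to floating point error).
import torch
from comfy.latent_formats import Flux  # assumed module path for the classes above

fmt = Flux()
latent = torch.randn(1, 16, 32, 32)  # 16 channels, matching the 16 rows of latent_rgb_factors
restored = fmt.process_out(fmt.process_in(latent))
print(torch.allclose(latent, restored, atol=1e-6))  # True
```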
@@ -9,6 +9,7 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
import comfy.ops
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
|
||||
@@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
|
||||
[out_features // 2, in_features], dtype=dtype, device=device))
|
||||
|
||||
def forward(self, input):
|
||||
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
|
||||
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
|
||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||
|
||||
# norms
|
||||
@@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
beta = self.beta
|
||||
if self.beta is not None:
|
||||
beta = beta.to(dtype=x.dtype, device=x.device)
|
||||
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
|
||||
if beta is not None:
|
||||
beta = comfy.ops.cast_to_input(beta, x)
|
||||
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(
|
||||
@@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
|
||||
scale_base = 512,
|
||||
interpolation_factor = 1.,
|
||||
base = 10000,
|
||||
base_rescale_factor = 1.
|
||||
base_rescale_factor = 1.,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
@@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
|
||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
# inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
|
||||
|
||||
assert interpolation_factor >= 1.
|
||||
self.interpolation_factor = interpolation_factor
|
||||
@@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
|
||||
|
||||
t = t / self.interpolation_factor
|
||||
|
||||
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
|
||||
freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
|
||||
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||
|
||||
if self.scale is None:
|
||||
return freqs, 1.
|
||||
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
|
||||
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
||||
scale = torch.cat((scale, scale), dim = -1)
|
||||
|
||||
return freqs, scale
|
||||
@@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
|
||||
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
||||
|
||||
if rotary_pos_emb:
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
|
||||
else:
|
||||
self.rotary_pos_emb = None
|
||||
|
||||
|
@@ -8,6 +8,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
@@ -406,10 +408,7 @@ class MMDiT(nn.Module):
|
||||
|
||||
def patchify(self, x):
|
||||
B, C, H, W = x.size()
|
||||
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
|
||||
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
|
||||
|
||||
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
@@ -427,7 +426,7 @@ class MMDiT(nn.Module):
|
||||
max_dim = max(h, w)
|
||||
|
||||
cur_dim = self.h_max
|
||||
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
|
||||
pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
|
||||
|
||||
if max_dim > cur_dim:
|
||||
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
|
||||
@@ -455,7 +454,7 @@ class MMDiT(nn.Module):
|
||||
t = timestep
|
||||
|
||||
c = self.cond_seq_linear(c_seq) # B, T_c, D
|
||||
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
|
||||
c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
|
||||
|
||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||
|
||||
|
@@ -19,14 +19,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
import comfy.ops
|
||||
|
||||
class OptimizedAttention(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
@@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
|
||||
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
||||
def __init__(self, dim, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
|
||||
return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
comfy/ldm/common_dit.py (new file, 8 lines)
@@ -0,0 +1,8 @@
import torch

def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
        padding_mode = "reflect"
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
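A small usage sketch for the new helper (not part of the diff; assumes ComfyUI is importable): spatial dims are rounded up to the next multiple of the patch size by padding on the right and bottom only.

import torch
from comfy.ldm.common_dit import pad_to_patch_size

x = torch.randn(1, 4, 61, 45)                      # latent with odd spatial dims
padded = pad_to_patch_size(x, patch_size=(2, 2))
print(padded.shape)                                 # torch.Size([1, 4, 62, 46])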
comfy/ldm/flux/layers.py (new file, 251 lines)
@@ -0,0 +1,251 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
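A quick shape check for the sinusoidal embedding above (standalone, not part of the diff):

import torch
from comfy.ldm.flux.layers import timestep_embedding

t = torch.tensor([0.0, 0.25, 1.0])        # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=256)       # internally scaled by time_factor=1000
print(emb.shape)                           # torch.Size([3, 256]); cos half followed by sin half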
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
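To make the modulation layout concrete (a standalone sketch, not the diff): with double=True the linear projection emits 6 * dim channels that chunk into two (shift, scale, gate) triples, one for the attention branch and one for the MLP branch.

import torch

batch, dim = 2, 8
projected = torch.randn(batch, 6 * dim)            # stand-in for self.lin(silu(vec))
out = projected[:, None, :].chunk(6, dim=-1)       # six tensors of shape (batch, 1, dim)
(shift1, scale1, gate1), (shift2, scale2, gate2) = out[:3], out[3:]
print(shift1.shape)                                 # torch.Size([2, 1, 8])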
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2), pe=pe)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
img += img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt blocks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = txt.clip(-65504, 65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = x.clip(-65504, 65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
comfy/ldm/flux/math.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True)
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
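A standalone sanity check (not part of the diff) that rope builds 2x2 rotation matrices and apply_rope therefore leaves per-token norms unchanged; the extra unsqueeze mimics the head axis that EmbedND adds.

import torch
from comfy.ldm.flux.math import rope, apply_rope

pos = torch.arange(6)[None, :]                 # (batch=1, seq=6) position ids
pe = rope(pos, dim=16, theta=10000)            # (1, 6, 8, 2, 2) rotation matrices
pe = pe.unsqueeze(1)                           # broadcast over heads, as EmbedND does

q = torch.randn(1, 2, 6, 16)                   # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 6, 16)
q_rot, k_rot = apply_rope(q, k, pe)
print(q_rot.shape)                                                      # torch.Size([1, 2, 6, 16])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))    # True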
comfy/ldm/flux/model.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
#Original code can be found on: https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
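The token and position-id bookkeeping in forward, traced standalone on a 61x45 latent (16 channels assumed for illustration): padding rounds the latent up to the patch grid, rearrange flattens 2x2 patches into tokens, and img_ids carries the (row, column) position of every token.

import torch
from einops import rearrange, repeat

bs, c, h, w = 1, 16, 61, 45
patch_size = 2
x = torch.nn.functional.pad(torch.randn(bs, c, h, w), (0, 1, 0, 1))    # rounds 61x45 up to 62x46
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

h_len = (h + (patch_size // 2)) // patch_size                           # 31
w_len = (w + (patch_size // 2)) // patch_size                           # 23
img_ids = torch.zeros((h_len, w_len, 3))
img_ids[..., 1] = torch.linspace(0, h_len - 1, steps=h_len)[:, None]    # row index
img_ids[..., 2] = torch.linspace(0, w_len - 1, steps=w_len)[None, :]    # column index
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
print(img.shape, img_ids.shape)        # torch.Size([1, 713, 64]) torch.Size([1, 713, 3])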
comfy/ldm/hydit/attn_layers.py (new file, 218 lines)
@@ -0,0 +1,218 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Union, Optional
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
|
||||
"""
|
||||
Reshape frequency tensor for broadcasting it with another tensor.
|
||||
|
||||
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
||||
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
||||
|
||||
Args:
|
||||
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
||||
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
||||
head_first (bool): head dimension first (except batch dim) or not.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Reshaped frequency tensor.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the frequency tensor doesn't match the expected shape.
|
||||
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
||||
"""
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
if isinstance(freqs_cis, tuple):
|
||||
# freqs_cis: (cos, sin) in real space
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
# freqs_cis: values in complex space
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: Optional[torch.Tensor],
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
head_first: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
||||
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
||||
returned as real tensors.
|
||||
|
||||
Args:
|
||||
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
||||
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
||||
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
|
||||
head_first (bool): head dimension first (except batch dim) or not.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
|
||||
"""
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
xq_out = (xq * cos + rotate_half(xq) * sin)
|
||||
if xk is not None:
|
||||
xk_out = (xk * cos + rotate_half(xk) * sin)
|
||||
else:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
|
||||
return xq_out, xk_out
|
||||
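A standalone check (not part of the diff) for the real-valued (cos, sin) branch: with tables built so each (even, odd) channel pair shares one angle, apply_rotary_emb is a plain 2D rotation per pair, so per-token norms are preserved. The table construction below is an illustrative assumption, not how HunYuanDiT builds them.

import torch
from comfy.ldm.hydit.attn_layers import apply_rotary_emb

B, S, H, D = 1, 7, 4, 16                           # head_first=False layout: [B, S, H, D]
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
angles = torch.rand(S, D // 2) * 3.14159
cos = angles.cos().repeat_interleave(2, dim=-1)    # (S, D), one angle per channel pair
sin = angles.sin().repeat_interleave(2, dim=-1)
q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin))
print(q_rot.shape)                                                      # torch.Size([1, 7, 4, 16])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))    # True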
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""
|
||||
Use QK Normalization.
|
||||
"""
|
||||
def __init__(self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
attn_precision=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.attn_precision = attn_precision
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
self.num_heads = num_heads
|
||||
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
|
||||
self.head_dim = self.qdim // num_heads
|
||||
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
||||
self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
# TODO: eps should be 1 / 65530 if using fp16
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, y, freqs_cis_img=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x: torch.Tensor
|
||||
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
|
||||
y: torch.Tensor
|
||||
(batch, seqlen2, hidden_dim2)
|
||||
freqs_cis_img: torch.Tensor
|
||||
(batch, hidden_dim // 2), RoPE for image
|
||||
"""
|
||||
b, s1, c = x.shape # [b, s1, D]
|
||||
_, s2, c = y.shape # [b, s2, 1024]
|
||||
|
||||
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
|
||||
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
|
||||
k, v = kv.unbind(dim=2) # [b, s, h, d]
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if freqs_cis_img is not None:
|
||||
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
|
||||
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
|
||||
q = qq
|
||||
|
||||
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
|
||||
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
|
||||
v = v.transpose(-2, -3).contiguous()
|
||||
|
||||
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||
|
||||
out = self.out_proj(context) # context.reshape - B, L1, -1
|
||||
out = self.proj_drop(out)
|
||||
|
||||
out_tuple = (out,)
|
||||
|
||||
return out_tuple
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
We rename some layer names to align with flash attention
|
||||
"""
|
||||
def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.attn_precision = attn_precision
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.head_dim = self.dim // num_heads
|
||||
# This assertion is aligned with flash attention
|
||||
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
# qkv --> Wqkv
|
||||
self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
# TODO: eps should be 1 / 65530 if using fp16
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, freqs_cis_img=None):
|
||||
B, N, C = x.shape
|
||||
qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
|
||||
q, k, v = qkv.unbind(0) # [b, h, s, d]
|
||||
q = self.q_norm(q) # [b, h, s, d]
|
||||
k = self.k_norm(k) # [b, h, s, d]
|
||||
|
||||
# Apply RoPE if needed
|
||||
if freqs_cis_img is not None:
|
||||
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
|
||||
assert qq.shape == q.shape and kk.shape == k.shape, \
|
||||
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
|
||||
q, k = qq, kk
|
||||
|
||||
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
|
||||
x = self.out_proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
out_tuple = (x,)
|
||||
|
||||
return out_tuple
|
comfy/ldm/hydit/controlnet.py (new file, 321 lines)
@@ -0,0 +1,321 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
||||
Mlp,
|
||||
TimestepEmbedder,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from .poolers import AttentionPool
|
||||
|
||||
import comfy.latent_formats
|
||||
from .models import HunYuanDiTBlock, calc_rope
|
||||
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
|
||||
class HunYuanControlNet(nn.Module):
|
||||
"""
|
||||
HunYuanDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
||||
|
||||
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: argparse.Namespace
|
||||
The arguments parsed by argparse.
|
||||
input_size: tuple
|
||||
The size of the input image.
|
||||
patch_size: int
|
||||
The size of the patch.
|
||||
in_channels: int
|
||||
The number of input channels.
|
||||
hidden_size: int
|
||||
The hidden size of the transformer backbone.
|
||||
depth: int
|
||||
The number of transformer blocks.
|
||||
num_heads: int
|
||||
The number of attention heads.
|
||||
mlp_ratio: float
|
||||
The ratio of the hidden size of the MLP in the transformer block.
|
||||
log_fn: callable
|
||||
The logging function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: tuple = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
hidden_size: int = 1408,
|
||||
depth: int = 40,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.3637,
|
||||
text_states_dim=1024,
|
||||
text_states_dim_t5=2048,
|
||||
text_len=77,
|
||||
text_len_t5=256,
|
||||
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
|
||||
size_cond=False,
|
||||
use_style_cond=False,
|
||||
learn_sigma=True,
|
||||
norm="layer",
|
||||
log_fn: callable = print,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_fn = log_fn
|
||||
self.depth = depth
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.text_states_dim = text_states_dim
|
||||
self.text_states_dim_t5 = text_states_dim_t5
|
||||
self.text_len = text_len
|
||||
self.text_len_t5 = text_len_t5
|
||||
self.size_cond = size_cond
|
||||
self.use_style_cond = use_style_cond
|
||||
self.norm = norm
|
||||
self.dtype = dtype
|
||||
self.latent_format = comfy.latent_formats.SDXL
|
||||
|
||||
self.mlp_t5 = nn.Sequential(
|
||||
nn.Linear(
|
||||
self.text_states_dim_t5,
|
||||
self.text_states_dim_t5 * 4,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
self.text_states_dim_t5 * 4,
|
||||
self.text_states_dim,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
# learnable replace
|
||||
self.text_embedding_padding = nn.Parameter(
|
||||
torch.randn(
|
||||
self.text_len + self.text_len_t5,
|
||||
self.text_states_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
# Attention pooling
|
||||
pooler_out_dim = 1024
|
||||
self.pooler = AttentionPool(
|
||||
self.text_len_t5,
|
||||
self.text_states_dim_t5,
|
||||
num_heads=8,
|
||||
output_dim=pooler_out_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Dimension of the extra input vectors
|
||||
self.extra_in_dim = pooler_out_dim
|
||||
|
||||
if self.size_cond:
|
||||
# Image size and crop size conditions
|
||||
self.extra_in_dim += 6 * 256
|
||||
|
||||
if self.use_style_cond:
|
||||
# Here we use a default learned embedder layer for future extension.
|
||||
self.style_embedder = nn.Embedding(
|
||||
1, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.extra_in_dim += hidden_size
|
||||
|
||||
# Text embedding for `add`
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.extra_embedder = nn.Sequential(
|
||||
operations.Linear(
|
||||
self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
|
||||
),
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HunYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
HunYuanDiTBlock(
|
||||
hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
text_states_dim=self.text_states_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_type=self.norm,
|
||||
skip=False,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(19)
|
||||
]
|
||||
)
|
||||
|
||||
# Input zero linear for the first block
|
||||
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
|
||||
# Output zero linear for the every block
|
||||
self.after_proj_list = nn.ModuleList(
|
||||
[
|
||||
|
||||
operations.Linear(
|
||||
self.hidden_size, self.hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(len(self.blocks))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
hint,
|
||||
timesteps,
|
||||
context,#encoder_hidden_states=None,
|
||||
text_embedding_mask=None,
|
||||
encoder_hidden_states_t5=None,
|
||||
text_embedding_mask_t5=None,
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
return_dict=False,
|
||||
**kwarg,
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: torch.Tensor
|
||||
(B, D, H, W)
|
||||
t: torch.Tensor
|
||||
(B)
|
||||
encoder_hidden_states: torch.Tensor
|
||||
CLIP text embedding, (B, L_clip, D)
|
||||
text_embedding_mask: torch.Tensor
|
||||
CLIP text embedding mask, (B, L_clip)
|
||||
encoder_hidden_states_t5: torch.Tensor
|
||||
T5 text embedding, (B, L_t5, D)
|
||||
text_embedding_mask_t5: torch.Tensor
|
||||
T5 text embedding mask, (B, L_t5)
|
||||
image_meta_size: torch.Tensor
|
||||
(B, 6)
|
||||
style: torch.Tensor
|
||||
(B)
|
||||
cos_cis_img: torch.Tensor
|
||||
sin_cis_img: torch.Tensor
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
condition = hint
|
||||
if condition.shape[0] == 1:
|
||||
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
|
||||
|
||||
text_states = context # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
text_states_mask = text_embedding_mask.bool() # 2,77
|
||||
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
|
||||
b_t5, l_t5, c_t5 = text_states_t5.shape
|
||||
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
||||
|
||||
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
|
||||
|
||||
text_states[:, -self.text_len :] = torch.where(
|
||||
text_states_mask[:, -self.text_len :].unsqueeze(2),
|
||||
text_states[:, -self.text_len :],
|
||||
padding[: self.text_len],
|
||||
)
|
||||
text_states_t5[:, -self.text_len_t5 :] = torch.where(
|
||||
text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
|
||||
text_states_t5[:, -self.text_len_t5 :],
|
||||
padding[self.text_len :],
|
||||
)
|
||||
|
||||
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024
|
||||
|
||||
# _, _, oh, ow = x.shape
|
||||
# th, tw = oh // self.patch_size, ow // self.patch_size
|
||||
|
||||
# Get image RoPE embedding according to `reso`lution.
|
||||
freqs_cis_img = calc_rope(
|
||||
x, self.patch_size, self.hidden_size // self.num_heads
|
||||
) # (cos_cis_img, sin_cis_img)
|
||||
|
||||
# ========================= Build time and image embedding =========================
|
||||
t = self.t_embedder(timesteps, dtype=self.dtype)
|
||||
x = self.x_embedder(x)
|
||||
|
||||
# ========================= Concatenate all extra vectors =========================
|
||||
# Build text tokens with pooling
|
||||
extra_vec = self.pooler(encoder_hidden_states_t5)
|
||||
|
||||
# Build image meta size tokens if applicable
|
||||
# if image_meta_size is not None:
|
||||
# image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
|
||||
# if image_meta_size.dtype != self.dtype:
|
||||
# image_meta_size = image_meta_size.half()
|
||||
# image_meta_size = image_meta_size.view(-1, 6 * 256)
|
||||
# extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
|
||||
|
||||
# Build style tokens
|
||||
if style is not None:
|
||||
style_embedding = self.style_embedder(style)
|
||||
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
# ========================= Deal with Condition =========================
|
||||
condition = self.x_embedder(condition)
|
||||
|
||||
# ========================= Forward pass through HunYuanDiT blocks =========================
|
||||
controls = []
|
||||
x = x + self.before_proj(condition) # add condition
|
||||
for layer, block in enumerate(self.blocks):
|
||||
x = block(x, c, text_states, freqs_cis_img)
|
||||
controls.append(self.after_proj_list[layer](x)) # zero linear for output
|
||||
|
||||
return {"output": controls}
|
comfy/ldm/hydit/models.py (new file, 410 lines)
@@ -0,0 +1,410 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from .attn_layers import Attention, CrossAttention
|
||||
from .poolers import AttentionPool
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
def calc_rope(x, patch_size, head_size):
|
||||
th = (x.shape[2] + (patch_size // 2)) // patch_size
|
||||
tw = (x.shape[3] + (patch_size // 2)) // patch_size
|
||||
base_size = 512 // 8 // patch_size
|
||||
start, stop = get_fill_resize_and_crop((th, tw), base_size)
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
||||
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
rope = (rope[0].to(x), rope[1].to(x))
|
||||
return rope
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
"""
|
||||
A HunYuanDiT block with `add` conditioning.
|
||||
"""
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_type="layer",
|
||||
skip=False,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
use_ele_affine = True
|
||||
|
||||
if norm_type == "layer":
|
||||
norm_layer = operations.LayerNorm
|
||||
elif norm_type == "rms":
|
||||
norm_layer = RMSNorm
|
||||
else:
|
||||
raise ValueError(f"Unknown norm_type: {norm_type}")
|
||||
|
||||
# ========================= Self-Attention =========================
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# ========================= FFN =========================
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# ========================= Add =========================
|
||||
# Simply use add like SDXL.
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# ========================= Cross-Attention =========================
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
|
||||
qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
# ========================= Skip Connection =========================
|
||||
if skip:
|
||||
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
|
||||
# Long Skip Connection
|
||||
if self.skip_linear is not None:
|
||||
cat = torch.cat([x, skip], dim=-1)
|
||||
if cat.dtype != x.dtype:
|
||||
cat = cat.to(x.dtype)
|
||||
cat = self.skip_norm(cat)
|
||||
x = self.skip_linear(cat)
|
||||
|
||||
# Self-Attention
|
||||
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
|
||||
attn_inputs = (
|
||||
self.norm1(x) + shift_msa, freq_cis_img,
|
||||
)
|
||||
x = x + self.attn1(*attn_inputs)[0]
|
||||
|
||||
# Cross-Attention
|
||||
cross_inputs = (
|
||||
self.norm3(x), text_states, freq_cis_img
|
||||
)
|
||||
x = x + self.attn2(*cross_inputs)[0]
|
||||
|
||||
# FFN Layer
|
||||
mlp_inputs = self.norm2(x)
|
||||
x = x + self.mlp(mlp_inputs)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
|
||||
return self._forward(x, c, text_states, freq_cis_img, skip)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of HunYuanDiT.
|
||||
"""
|
||||
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class HunYuanDiT(nn.Module):
|
||||
"""
|
||||
HunYuanDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
||||
|
||||
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: argparse.Namespace
|
||||
The arguments parsed by argparse.
|
||||
input_size: tuple
|
||||
The size of the input image.
|
||||
patch_size: int
|
||||
The size of the patch.
|
||||
in_channels: int
|
||||
The number of input channels.
|
||||
hidden_size: int
|
||||
The hidden size of the transformer backbone.
|
||||
depth: int
|
||||
The number of transformer blocks.
|
||||
num_heads: int
|
||||
The number of attention heads.
|
||||
mlp_ratio: float
|
||||
The ratio of the hidden size of the MLP in the transformer block.
|
||||
log_fn: callable
|
||||
The logging function.
|
||||
"""
|
||||
#@register_to_config
|
||||
def __init__(self,
|
||||
input_size: tuple = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
hidden_size: int = 1152,
|
||||
depth: int = 28,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
text_states_dim = 1024,
|
||||
text_states_dim_t5 = 2048,
|
||||
text_len = 77,
|
||||
text_len_t5 = 256,
|
||||
qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
|
||||
size_cond = False,
|
||||
use_style_cond = False,
|
||||
learn_sigma = True,
|
||||
norm = "layer",
|
||||
log_fn: callable = print,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_fn = log_fn
|
||||
self.depth = depth
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.text_states_dim = text_states_dim
|
||||
self.text_states_dim_t5 = text_states_dim_t5
|
||||
self.text_len = text_len
|
||||
self.text_len_t5 = text_len_t5
|
||||
self.size_cond = size_cond
|
||||
self.use_style_cond = use_style_cond
|
||||
self.norm = norm
|
||||
self.dtype = dtype
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
|
||||
self.mlp_t5 = nn.Sequential(
|
||||
operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
# learnable replace
|
||||
self.text_embedding_padding = nn.Parameter(
|
||||
torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))
|
||||
|
||||
# Attention pooling
|
||||
pooler_out_dim = 1024
|
||||
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# Dimension of the extra input vectors
|
||||
self.extra_in_dim = pooler_out_dim
|
||||
|
||||
if self.size_cond:
|
||||
# Image size and crop size conditions
|
||||
self.extra_in_dim += 6 * 256
|
||||
|
||||
if self.use_style_cond:
|
||||
# Here we use a default learned embedder layer for future extension.
|
||||
self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
|
||||
self.extra_in_dim += hidden_size
|
||||
|
||||
# Text embedding for `add`
|
||||
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.extra_embedder = nn.Sequential(
|
||||
operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HunYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
text_states_dim=self.text_states_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_type=self.norm,
|
||||
skip=layer > depth // 2,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
self.unpatchify_channels = self.out_channels
|
||||
|
||||
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
context,#encoder_hidden_states=None,
|
||||
text_embedding_mask=None,
|
||||
encoder_hidden_states_t5=None,
|
||||
text_embedding_mask_t5=None,
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
return_dict=False,
|
||||
control=None,
|
||||
transformer_options=None,
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: torch.Tensor
|
||||
(B, D, H, W)
|
||||
t: torch.Tensor
|
||||
(B)
|
||||
encoder_hidden_states: torch.Tensor
|
||||
CLIP text embedding, (B, L_clip, D)
|
||||
text_embedding_mask: torch.Tensor
|
||||
CLIP text embedding mask, (B, L_clip)
|
||||
encoder_hidden_states_t5: torch.Tensor
|
||||
T5 text embedding, (B, L_t5, D)
|
||||
text_embedding_mask_t5: torch.Tensor
|
||||
T5 text embedding mask, (B, L_t5)
|
||||
image_meta_size: torch.Tensor
|
||||
(B, 6)
|
||||
style: torch.Tensor
|
||||
(B)
|
||||
cos_cis_img: torch.Tensor
|
||||
sin_cis_img: torch.Tensor
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
encoder_hidden_states = context
|
||||
        text_states = encoder_hidden_states                        # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5                  # 2,256,2048
        text_states_mask = text_embedding_mask.bool()              # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()        # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
        text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])

        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024
        # clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)

        _, _, oh, ow = x.shape
        th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size

        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads)  # (cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(t, dtype=x.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        if self.size_cond:
            image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype)  # [B * 6, 256]
            image_meta_size = image_meta_size.view(-1, 6 * 256)
            extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens
        if self.use_style_cond:
            if style is None:
                style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        controls = None
        if control:
            controls = control.get("output", None)
        # ========================= Forward pass through HunYuanDiT blocks =========================
        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                if controls is not None:
                    skip = skips.pop() + controls.pop()
                else:
                    skip = skips.pop()
                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)
            else:
                x = block(x, c, text_states, freqs_cis_img)  # (N, L, D)

            if layer < (self.depth // 2 - 1):
                skips.append(x)
        if controls is not None and len(controls) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # ========================= Final layer =========================
        x = self.final_layer(x, c)  # (N, L, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, th, tw)  # (N, out_channels, H, W)

        if return_dict:
            return {'x': x}
        if self.learn_sigma:
            return x[:,:self.out_channels // 2,:oh,:ow]
        return x[:,:,:oh,:ow]
    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.unpatchify_channels
        p = self.x_embedder.patch_size[0]
        # h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
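For orientation, here is a minimal, self-contained shape check of the reshaping unpatchify performs; the batch size, grid size, patch size and channel count are invented for the example and are not taken from any model config:

import torch

# hypothetical sizes: 2 samples, a 16x16 token grid, patch size 2, 8 output channels
N, h, w, p, c = 2, 16, 16, 2, 8
tokens = torch.randn(N, h * w, p * p * c)      # (N, T, patch_size**2 * C), like the final_layer output
x = tokens.reshape(N, h, w, p, p, c)
x = torch.einsum('nhwpqc->nchpwq', x)          # move channels first, interleave patch pixels
imgs = x.reshape(N, c, h * p, w * p)           # (N, C, H, W)
assert imgs.shape == (2, 8, 32, 32)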
comfy/ldm/hydit/poolers.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops


class AttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
        self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x):
        x = x[:,:self.positional_embedding.shape[0] - 1]
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC

        q = self.q_proj(x[:1])
        k = self.k_proj(x)
        v = self.v_proj(x)

        batch_size = q.shape[1]
        head_dim = self.embed_dim // self.num_heads
        q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)

        attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)

        attn_output = self.c_proj(attn_output)
        return attn_output.squeeze(0)
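A quick usage sketch of this pooler, assuming a ComfyUI checkout is importable. The 256-token / 2048-dim / 1024-dim figures mirror the mT5 pooling configuration HunyuanDiT uses but are illustrative here, and plain torch.nn stands in for the comfy ops object:

import torch
import torch.nn as nn
from comfy.ldm.hydit.poolers import AttentionPool

pool = AttentionPool(spacial_dim=256, embed_dim=2048, num_heads=16, output_dim=1024, operations=nn)
pool.positional_embedding.data.normal_(std=0.02)   # torch.empty leaves it uninitialized

t5_hidden = torch.randn(2, 256, 2048)              # (batch, tokens, dim) mT5 hidden states
pooled = pool(t5_hidden)
print(pooled.shape)                                 # torch.Size([2, 1024])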
comfy/ldm/hydit/posemb_layers.py (new file, 224 lines)
@@ -0,0 +1,224 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
|
||||
|
||||
def _to_tuple(x):
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_fill_resize_and_crop(src, tgt):
|
||||
th, tw = _to_tuple(tgt)
|
||||
h, w = _to_tuple(src)
|
||||
|
||||
tr = th / tw # base resolution
|
||||
r = h / w # target resolution
|
||||
|
||||
# resize
|
||||
if r > tr:
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_meshgrid(start, *args):
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start)
|
||||
start = (0, 0)
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start)
|
||||
stop = _to_tuple(args[0])
|
||||
num = (stop[0] - start[0], stop[1] - start[1])
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start)
|
||||
stop = _to_tuple(args[0])
|
||||
num = _to_tuple(args[1])
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
return grid
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid = get_meshgrid(start, *args) # [2, H, w]
|
||||
# grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
# grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
# grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (W,H)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Rotary Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
|
||||
"""
|
||||
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
embed_dim: int
|
||||
embedding dimension size
|
||||
start: int or tuple of int
|
||||
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
|
||||
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
use_real: bool
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pos_embed: torch.Tensor
|
||||
[HW, D/2]
|
||||
"""
|
||||
grid = get_meshgrid(start, *args) # [2, H, w]
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
|
||||
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
|
||||
def calc_sizes(rope_img, patch_size, th, tw):
|
||||
if rope_img == 'extend':
|
||||
# Expansion mode
|
||||
sub_args = [(th, tw)]
|
||||
elif rope_img.startswith('base'):
|
||||
# Based on the specified dimensions, other dimensions are obtained through interpolation.
|
||||
base_size = int(rope_img[4:]) // 8 // patch_size
|
||||
start, stop = get_fill_resize_and_crop((th, tw), base_size)
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
else:
|
||||
raise ValueError(f"Unknown rope_img: {rope_img}")
|
||||
return sub_args
|
||||
|
||||
|
||||
def init_image_posemb(rope_img,
|
||||
resolutions,
|
||||
patch_size,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
log_fn,
|
||||
rope_real=True,
|
||||
):
|
||||
freqs_cis_img = {}
|
||||
for reso in resolutions:
|
||||
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
|
||||
sub_args = calc_sizes(rope_img, patch_size, th, tw)
|
||||
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
|
||||
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
|
||||
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
|
||||
return freqs_cis_img
|
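As a rough illustration of how these helpers are driven, the sketch below builds the image RoPE tables for one hypothetical resolution. The head dimension of 88 corresponds to hidden_size 1408 with an assumed 16 heads (DiT-g/2); the 64x64 token grid and base size are made-up values for the example:

from comfy.ldm.hydit.posemb_layers import get_fill_resize_and_crop, get_2d_rotary_pos_embed

head_dim = 88            # hidden_size // num_heads, assuming 1408 // 16 for DiT-g/2
th, tw = 64, 64          # hypothetical token grid, e.g. 1024px image -> 128px latent -> 64x64 patches
base_size = 32           # e.g. rope_img='base512' -> 512 // 8 // patch_size

start, stop = get_fill_resize_and_crop((th, tw), base_size)
cos, sin = get_2d_rotary_pos_embed(head_dim, start, stop, (th, tw), use_real=True)
print(cos.shape, sin.shape)   # torch.Size([4096, 88]) torch.Size([4096, 88]) -- one row per image token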
@@ -7,6 +7,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
from .. import attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def default(x, y):
|
||||
if x is not None:
|
||||
@@ -68,12 +71,14 @@ class PatchEmbed(nn.Module):
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = True,
|
||||
padding_mode='circular',
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
self.padding_mode = padding_mode
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
||||
@@ -107,9 +112,7 @@ class PatchEmbed(nn.Module):
|
||||
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
# )
|
||||
if self.dynamic_img_pad:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
@@ -230,34 +233,8 @@ class TimestepEmbedder(nn.Module):
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||
/ half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
@@ -949,7 +926,7 @@ class MMDiT(nn.Module):
|
||||
context = self.context_processor(context)
|
||||
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
||||
x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
|
@@ -809,7 +809,7 @@ class UNetModel(nn.Module):
|
||||
self.out = nn.Sequential(
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||
operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
|
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
@@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map[lora_key] = k
|
||||
|
||||
for k in sdk: #OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
@@ -245,6 +269,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
@@ -282,4 +307,18 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.HunyuanDiT):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
|
||||
|
||||
if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
|
||||
diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
return key_map
|
||||
|
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@@ -7,8 +25,11 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
@@ -25,6 +46,7 @@ class ModelType(Enum):
|
||||
EDM = 5
|
||||
FLOW = 6
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
@@ -52,6 +74,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -67,16 +92,21 @@ class BaseModel(torch.nn.Module):
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
if model_config.custom_operations is None:
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@@ -87,6 +117,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.concat_keys = ()
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
@@ -245,11 +276,11 @@ class BaseModel(torch.nn.Module):
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
|
||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * math.prod(input_shape[2:])
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
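A back-of-the-envelope check of the new fp16/flash-attention branch above, using made-up numbers (a 1x4x128x128 latent, a 2-byte dtype, and a hypothetical memory_usage_factor of 2.0):

area = 1 * (128 * 128)                 # input_shape[0] * prod(input_shape[2:]) for the latent
dtype_size = 2                         # bytes per element for fp16/bf16
memory_usage_factor = 2.0              # per-model constant from the model config (illustrative)
est_bytes = (area * dtype_size * 0.01 * memory_usage_factor) * (1024 * 1024)
print(est_bytes / (1024 ** 3))         # ~0.64, i.e. roughly 0.64 GiB budgeted for inference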
|
||||
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
|
||||
@@ -347,6 +378,7 @@ class SDXL(BaseModel):
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
|
||||
|
||||
class SVD_img2vid(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
@@ -587,17 +619,6 @@ class SD3(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this probably needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
|
||||
else:
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * 0.3) * (1024 * 1024)
|
||||
|
||||
class AuraFlow(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@@ -648,3 +669,50 @@ class StableAudio1(BaseModel):
|
||||
for l in s:
|
||||
sd["{}{}".format(k, l)] = s[l]
|
||||
return sd
|
||||
|
||||
class HunyuanDiT(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['text_embedding_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
|
||||
conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
|
||||
if conditioning_mt5xl is not None:
|
||||
out['encoder_hidden_states_t5'] = comfy.conds.CONDRegular(conditioning_mt5xl)
|
||||
|
||||
attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
|
||||
if attention_mask_mt5xl is not None:
|
||||
out['text_embedding_mask_t5'] = comfy.conds.CONDRegular(attention_mask_mt5xl)
|
||||
|
||||
width = kwargs.get("width", 768)
|
||||
height = kwargs.get("height", 768)
|
||||
crop_w = kwargs.get("crop_w", 0)
|
||||
crop_h = kwargs.get("crop_h", 0)
|
||||
target_width = kwargs.get("target_width", width)
|
||||
target_height = kwargs.get("target_height", height)
|
||||
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
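For context, a small sketch of how the image_meta_size vector assembled here is later expanded inside the DiT when size_cond is enabled: each of the six numbers goes through the sinusoidal timestep embedding at 256 dims and the result is flattened. The sizes below are illustrative:

import torch
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding

meta = torch.FloatTensor([[1024., 1024., 1024., 1024., 0., 0.]])  # [height, width, target_height, target_width, 0, 0]
emb = timestep_embedding(meta.view(-1), 256)                       # [B * 6, 256]
emb = emb.view(-1, 6 * 256)                                        # [B, 1536], concatenated onto the pooled text vector
print(emb.shape)                                                   # torch.Size([1, 1536])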
|
||||
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||
return out
|
||||
|
@@ -115,6 +115,36 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
unet_config["n_layers"] = double_layers + single_layers
|
||||
return unet_config
|
||||
|
||||
if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "hydit"
|
||||
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
|
||||
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
|
||||
unet_config["mlp_ratio"] = 4.3637
|
||||
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
|
||||
unet_config["size_cond"] = True
|
||||
unet_config["use_style_cond"] = True
|
||||
unet_config["image_model"] = "hydit1"
|
||||
return unet_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -261,13 +291,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
||||
unet_key_prefix = "model.model."
|
||||
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
|
||||
unet_key_prefix = "model."
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
"model.model.", #audio models
|
||||
]
|
||||
counts = {k: 0 for k in candidates}
|
||||
for k in state_dict:
|
||||
for c in candidates:
|
||||
if k.startswith(c):
|
||||
counts[c] += 1
|
||||
break
|
||||
|
||||
top = max(counts, key=counts.get)
|
||||
if counts[top] > 5:
|
||||
return top
|
||||
else:
|
||||
unet_key_prefix = "model.diffusion_model."
|
||||
return unet_key_prefix
|
||||
return "model." #aura flow and others
|
||||
|
||||
|
||||
def convert_config(unet_config):
|
||||
new_config = unet_config.copy()
|
||||
@@ -456,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
@@ -482,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
if old_weight.shape[offset[0]] < offset[1] + offset[2]:
|
||||
exp = list(weight.shape)
|
||||
exp[offset[0]] = offset[1] + offset[2]
|
||||
new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
|
||||
new[:old_weight.shape[0]] = old_weight
|
||||
old_weight = new
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
|
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
@@ -273,9 +291,12 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_offloaded_memory(self):
|
||||
return self.model.model_size() - self.model.loaded_size()
|
||||
|
||||
def model_memory_required(self, device):
|
||||
if device == self.model.current_device:
|
||||
return 0
|
||||
if device == self.model.current_loaded_device():
|
||||
return self.model_offloaded_memory()
|
||||
else:
|
||||
return self.model_memory()
|
||||
|
||||
@@ -287,15 +308,21 @@ class LoadedModel:
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
@@ -304,21 +331,43 @@ class LoadedModel:
|
||||
return self.real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter > 0:
|
||||
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def model_unload(self, unpatch_weights=True):
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
return False
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
def model_use_more_vram(self, extra_memory):
|
||||
return self.model.partially_load(self.device, extra_memory)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
extra_memory -= m.model_use_more_vram(extra_memory)
|
||||
if extra_memory <= 0:
|
||||
break
|
||||
|
||||
def offloaded_memory(loaded_models, device):
|
||||
offloaded_mem = 0
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
offloaded_mem += m.model_offloaded_memory()
|
||||
return offloaded_mem
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024)
|
||||
return (1024 * 1024 * 1024) * 1.2
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
to_unload = []
|
||||
@@ -352,6 +401,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
@@ -362,14 +412,18 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = None
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
if get_free_memory(device) > memory_required:
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem > memory_required:
|
||||
break
|
||||
current_loaded_models[i].model_unload()
|
||||
unloaded_model.append(i)
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
unloaded_model.append(i)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
current_loaded_models.pop(i)
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
@@ -378,12 +432,17 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
|
||||
|
||||
models = set(models)
|
||||
|
||||
@@ -416,25 +475,36 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem, d, models_already_loaded)
|
||||
return
|
||||
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem < minimum_memory_required:
|
||||
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
||||
models_to_load = free_memory(minimum_memory_required, d)
|
||||
logging.info("{} models unloaded.".format(len(models_to_load)))
|
||||
else:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
if len(models_to_load) == 0:
|
||||
return
|
||||
|
||||
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
for loaded_model in models_already_loaded:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
@@ -446,8 +516,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
||||
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
|
||||
lowvram_model_memory = 0
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
@@ -455,6 +525,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
|
||||
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem > minimum_memory_required:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
return
|
||||
|
||||
|
||||
@@ -523,6 +601,9 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def maximum_vram_for_weights(device=None):
|
||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
@@ -532,12 +613,37 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
return torch.float8_e4m3fn
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
fp8_dtype = None
|
||||
try:
|
||||
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
if dtype in supported_dtypes:
|
||||
fp8_dtype = dtype
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if fp8_dtype is not None:
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
return torch.float32
|
||||
|
||||
# None means no manual cast
|
||||
@@ -553,13 +659,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and bf16_supported:
|
||||
return torch.bfloat16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
@@ -578,6 +684,20 @@ def text_encoder_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
if is_device_mps(load_device):
|
||||
return offload_device
|
||||
|
||||
mem_l = get_free_memory(load_device)
|
||||
mem_o = get_free_memory(offload_device)
|
||||
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||
return load_device
|
||||
else:
|
||||
return offload_device
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
@@ -649,18 +769,29 @@ def supports_cast(device, dtype): #TODO
|
||||
return True
|
||||
if dtype == torch.float16:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
if directml_enabled: #TODO: test this
|
||||
return False
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return True
|
||||
if dtype == torch.float8_e5m2:
|
||||
return True
|
||||
return False
|
||||
|
||||
def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
||||
if dtype is None:
|
||||
dtype = fallback_dtype
|
||||
elif dtype_size(dtype) > dtype_size(fallback_dtype):
|
||||
dtype = fallback_dtype
|
||||
|
||||
if not supports_cast(device, dtype):
|
||||
dtype = fallback_dtype
|
||||
|
||||
return dtype
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
@@ -856,7 +987,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
fp16_works = True
|
||||
|
||||
if fp16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
|
||||
@@ -876,9 +1007,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
||||
return False
|
||||
|
||||
if device is not None: #TODO not sure about mps bf16 support
|
||||
if device is not None:
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
@@ -886,23 +1017,23 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
if mps_mode():
|
||||
return True
|
||||
|
||||
if cpu_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cuda")
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
bf16_works = torch.cuda.is_bf16_supported()
|
||||
|
||||
if bf16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
|
||||
|
@@ -1,8 +1,27 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
import collections
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
@@ -63,10 +82,31 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
m.weight_function = None
|
||||
m.bias_function = None
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
self.model = model
|
||||
if not hasattr(self.model, 'device'):
|
||||
logging.debug("Model doesn't have a device attribute.")
|
||||
self.model.device = offload_device
|
||||
elif self.model.device is None:
|
||||
self.model.device = offload_device
|
||||
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
@@ -75,24 +115,32 @@ class ModelPatcher:
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.model_lowvram = False
|
||||
self.lowvram_patch_counter = 0
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
if not hasattr(self.model, 'lowvram_patch_counter'):
|
||||
self.model.lowvram_patch_counter = 0
|
||||
|
||||
if not hasattr(self.model, 'model_lowvram'):
|
||||
self.model.model_lowvram = False
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
def lowvram_patch_counter(self):
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -264,16 +312,16 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None):
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
weight = comfy.utils.get_attr(self.model, key)
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
@@ -302,29 +350,25 @@ class ModelPatcher:
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = self.model_size()
|
||||
|
||||
return self.model
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
for n, m in self.model.named_modules():
|
||||
lowvram_weight = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
if m.comfy_cast_weights:
|
||||
continue
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
@@ -346,15 +390,40 @@ class ModelPatcher:
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
if m.comfy_cast_weights:
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
self.patch_weight_to_device(weight_key, device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to)
|
||||
m.to(device_to)
|
||||
mem_counter += comfy.model_management.module_size(m)
|
||||
param = list(m.parameters())
|
||||
if len(param) > 0:
|
||||
weight = param[0]
|
||||
if weight.device == device_to:
|
||||
continue
|
||||
|
||||
weight_to = None
|
||||
if full_load:#TODO
|
||||
weight_to = device_to
|
||||
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
|
||||
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
||||
m.to(device_to)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
|
||||
self.model_lowvram = True
|
||||
self.lowvram_patch_counter = patch_counter
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||
self.model.model_lowvram = False
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
@@ -527,34 +596,89 @@ class ModelPatcher:
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
if unpatch_weights:
|
||||
if self.model_lowvram:
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
m.weight_function = None
|
||||
m.bias_function = None
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
self.model_lowvram = False
|
||||
self.lowvram_patch_counter = 0
|
||||
self.model.model_lowvram = False
|
||||
self.model.lowvram_patch_counter = 0
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.set_attr_param(self.model, k, self.backup[k])
|
||||
for k in keys:
|
||||
bk = self.backup[k]
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, k, bk.weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
|
||||
for n, m in list(self.model.named_modules())[::-1]:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
|
||||
shift_lowvram = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
|
||||
if m.weight is not None and m.weight.device != device_to:
|
||||
for key in [weight_key, bias_key]:
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
|
||||
m.to(device_to)
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0):
|
||||
self.unpatch_model(unpatch_weights=False)
|
||||
self.patch_model(patch_weights=False)
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False:
|
||||
return 0
|
||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||
full_load = True
|
||||
current_used = self.model.model_loaded_weight_memory
|
||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def current_loaded_device(self):
|
||||
return self.model.device
|
||||
|
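
The lowvram logic above rests on one pattern: modules that fall outside the memory budget keep their unpatched weights, and a LowVramPatch callable attached as the module's weight_function applies the patch to the temporary cast copy on every forward, while the backup dict records how each original weight should later be restored. A minimal self-contained sketch of that pattern follows (toy code, not the ComfyUI implementation; the additive patch stands in for calculate_weight()):

import collections
import torch

class LowVramPatchSketch:
    # Toy stand-in for LowVramPatch / calculate_weight(): apply a stored patch to a weight copy.
    def __init__(self, key, patches):
        self.key = key
        self.patches = patches

    def __call__(self, weight):
        # The patch touches only the cast copy, never the parameter kept on the offload device.
        return weight + self.patches[self.key]

patches = {"linear.weight": torch.full((4, 4), 0.5)}
linear = torch.nn.Linear(4, 4, bias=False)

# Mirrors the namedtuple backup above: the original weight plus a flag deciding between
# an in-place copy and a plain attribute swap when the model is unpatched.
BackupEntry = collections.namedtuple('BackupEntry', ['weight', 'inplace_update'])
backup = {"linear.weight": BackupEntry(linear.weight.detach().clone(), True)}

linear.weight_function = LowVramPatchSketch("linear.weight", patches)

def forward_with_cast(module, x):
    weight = module.weight.to(x.dtype)           # cast (and normally move) the stored weight
    if getattr(module, "weight_function", None) is not None:
        weight = module.weight_function(weight)  # patch the temporary copy on the fly
    return torch.nn.functional.linear(x, weight)

out = forward_with_cast(linear, torch.randn(1, 4))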
@@ -59,8 +59,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
        beta_schedule = sampling_settings.get("beta_schedule", "linear")
        linear_start = sampling_settings.get("linear_start", 0.00085)
        linear_end = sampling_settings.get("linear_end", 0.012)
        timesteps = sampling_settings.get("timesteps", 1000)

        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
        self.sigma_data = 1.0

    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@@ -271,3 +272,43 @@ class StableCascadeSampling(ModelSamplingDiscrete):

        percent = 1.0 - percent
        return self.sigma(torch.tensor(percent))


def flux_time_shift(mu: float, sigma: float, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

class ModelSamplingFlux(torch.nn.Module):
    def __init__(self, model_config=None):
        super().__init__()
        if model_config is not None:
            sampling_settings = model_config.sampling_settings
        else:
            sampling_settings = {}

        self.set_parameters(shift=sampling_settings.get("shift", 1.15))

    def set_parameters(self, shift=1.15, timesteps=10000):
        self.shift = shift
        ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
        self.register_buffer('sigmas', ts)

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        return sigma

    def sigma(self, timestep):
        return flux_time_shift(self.shift, 1.0, timestep)

    def percent_to_sigma(self, percent):
        if percent <= 0.0:
            return 1.0
        if percent >= 1.0:
            return 0.0
        return 1.0 - percent
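
For reference, the flux_time_shift() curve above maps a normalized timestep t in (0, 1] to a sigma in (0, 1]; with the exponent fixed at 1.0 it reduces to m*t / (1 + (m - 1)*t) with m = exp(shift), so sigma(1.0) is exactly 1.0 and small timesteps are compressed toward 0. A standalone sketch (the function is restated so the snippet runs on its own; printed values are approximate):

import math
import torch

def flux_time_shift(mu: float, sigma: float, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

shift = 1.15       # default used by ModelSamplingFlux above
timesteps = 10000
ts = torch.arange(1, timesteps + 1, 1) / timesteps
sigmas = torch.tensor([flux_time_shift(shift, 1.0, float(t)) for t in ts])

# sigma_min is sigmas[0] and sigma_max is sigmas[-1], matching the properties above.
print(float(sigmas[0]), float(sigmas[-1]))  # ~0.000316, 1.0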
comfy/ops.py
@@ -19,14 +19,27 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
||||
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
bias = s.bias_function(bias)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
@@ -168,6 +181,26 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
output_dtype = out_dtype
|
||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||
out_dtype = None
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
if "out_dtype" in kwargs:
|
||||
kwargs.pop("out_dtype")
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
@@ -202,3 +235,6 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
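
The refactor above splits the casting into reusable helpers: cast_to()/cast_to_input() move a stored parameter to the requested dtype/device, cast_bias_weight() now also works without an input tensor (as the new Embedding layer needs), and the weight_function/bias_function hooks are applied to the cast copies. A rough standalone sketch of the manual-cast idea, not the actual ops classes:

import torch

def cast_to(weight, dtype=None, device=None, non_blocking=False):
    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)

class ManualCastLinearSketch(torch.nn.Linear):
    # Parameters stay in their storage dtype/device; each call casts them to match the input
    # and runs optional patch hooks on the temporary copies.
    weight_function = None
    bias_function = None

    def forward(self, x):
        weight = cast_to(self.weight, dtype=x.dtype, device=x.device)
        bias = cast_to(self.bias, dtype=x.dtype, device=x.device) if self.bias is not None else None
        if self.weight_function is not None:
            weight = self.weight_function(weight)
        if self.bias_function is not None:
            bias = self.bias_function(bias)
        return torch.nn.functional.linear(x, weight, bias)

layer = ManualCastLinearSketch(8, 8).to(torch.float16)   # storage dtype
y = layer(torch.randn(2, 8, dtype=torch.float32))        # compute dtype follows the input
print(y.dtype)  # torch.float32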
@@ -61,7 +61,9 @@ def prepare_sampling(model, noise_shape, conds):
    device = model.load_device
    real_model = None
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
    real_model = model.model

    return real_model, conds, models
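
The change above separates two budgets: memory_required is sized for running cond and uncond together (batch doubled), while minimum_memory_required is the single-batch floor the loader must satisfy even if it leaves the model partially loaded. An illustrative sketch with a made-up estimator (the real model.memory_required() depends on the model type):

noise_shape = (4, 4, 64, 64)

def memory_required_sketch(shape, bytes_per_element=2, overhead=3.0):
    # Hypothetical estimator: activation cost proportional to the element count.
    n = 1
    for s in shape:
        n *= s
    return n * bytes_per_element * overhead

memory_required = memory_required_sketch([noise_shape[0] * 2] + list(noise_shape[1:]))
minimum_memory_required = memory_required_sketch([noise_shape[0]] + list(noise_shape[1:]))
print(memory_required, minimum_memory_required)  # the comfortable budget is twice the floor here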
@@ -6,6 +6,8 @@ from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.sampler_helpers
|
||||
import scipy
|
||||
import numpy
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
dims = tuple(x_in.shape[2:])
|
||||
@@ -169,7 +171,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) < free_memory:
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@@ -311,13 +313,18 @@ def simple_scheduler(model_sampling, steps):
|
||||
def ddim_scheduler(model_sampling, steps):
|
||||
s = model_sampling
|
||||
sigs = []
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
x = 1
|
||||
if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
|
||||
steps += 1
|
||||
sigs = []
|
||||
else:
|
||||
sigs = [0.0]
|
||||
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
x += ss
|
||||
sigs = sigs[::-1]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||
@@ -325,15 +332,34 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
append_zero = True
|
||||
if sgm:
|
||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||
else:
|
||||
if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
|
||||
steps += 1
|
||||
append_zero = False
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
|
||||
sigs = []
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
sigs.append(s.sigma(ts))
|
||||
sigs.append(float(s.sigma(ts)))
|
||||
|
||||
if append_zero:
|
||||
sigs += [0.0]
|
||||
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
# Implemented based on: https://arxiv.org/abs/2407.12173
|
||||
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
|
||||
total_timesteps = (len(model_sampling.sigmas) - 1)
|
||||
ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
|
||||
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
|
||||
|
||||
sigs = []
|
||||
for t in ts:
|
||||
sigs += [float(model_sampling.sigmas[int(t)])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
@@ -703,7 +729,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
@@ -719,6 +745,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
elif scheduler_name == "beta":
|
||||
sigmas = beta_scheduler(model_sampling, steps)
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
|
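
The new "beta" scheduler picks timestep indices with the inverse CDF of a Beta(alpha, beta) distribution (per the paper linked above), which concentrates steps near both ends of the sigma range. A standalone sketch of the same computation, using a toy linear sigma table in place of model_sampling.sigmas:

import numpy
import scipy.stats
import torch

def beta_scheduler_sketch(model_sigmas, steps, alpha=0.6, beta=0.6):
    total_timesteps = len(model_sigmas) - 1
    ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
    ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)

    sigs = [float(model_sigmas[int(t)]) for t in ts]
    sigs += [0.0]
    return torch.FloatTensor(sigs)

model_sigmas = torch.linspace(0.03, 14.6, 1000)  # toy stand-in, ascending like the real table
print(beta_scheduler_sketch(model_sigmas, steps=10))  # starts at sigma_max, ends at 0.0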
comfy/sd.py
@@ -17,11 +17,13 @@ from . import diffusers_convert
|
||||
from . import model_detection
|
||||
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
import comfy.text_encoders.sd2_clip
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -60,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@@ -69,20 +71,24 @@ class CLIP:
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
params['dtype'] = dtype
|
||||
|
||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
for dt in self.cond_stage_model.dtypes:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
if params['device'] != offload_device:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
if params['device'] == load_device:
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
self.layer_idx = None
|
||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@@ -135,7 +141,11 @@ class CLIP:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
sd_clip = self.cond_stage_model.state_dict()
|
||||
sd_tokenizer = self.tokenizer.state_dict()
|
||||
for k in sd_tokenizer:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
@@ -381,6 +391,8 @@ class CLIPType(Enum):
|
||||
STABLE_CASCADE = 2
|
||||
SD3 = 3
|
||||
STABLE_AUDIO = 4
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
clip_data = []
|
||||
@@ -408,35 +420,51 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif weight.shape[-1] == 2048:
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
clip_target.clip = sa_t5.SAT5Model
|
||||
clip_target.tokenizer = sa_t5.SAT5Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_DIT:
|
||||
clip_target.clip = comfy.text_encoders.hydit.HyditModel
|
||||
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
|
||||
elif clip_type == CLIPType.FLUX:
|
||||
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
|
||||
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
|
||||
dtype_t5 = None
|
||||
if weight is not None:
|
||||
dtype_t5 = weight.dtype
|
||||
|
||||
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif len(clip_data) == 3:
|
||||
clip_target.clip = sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@@ -478,25 +506,39 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
clip_target = None
|
||||
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return None
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("weight_dtype", None)
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
@@ -520,7 +562,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
@@ -539,7 +582,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
@@ -547,7 +590,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
@@ -556,7 +599,6 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||
sd = temp_sd
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
|
||||
@@ -583,7 +625,11 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
if dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
@@ -594,9 +640,9 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path):
|
||||
def load_unet(unet_path, dtype=None):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd)
|
||||
model = load_unet_state_dict(sd, dtype=dtype)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
|
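
A recurring ingredient in the sd.py changes above is counting the parameters of a state dict before constructing the CLIP object, so text_encoder_initial_device() can weigh the estimated weight footprint (parameters times bytes per element) against available memory. A standalone approximation of that bookkeeping (prefix filtering simplified; the dummy state dict is illustrative):

import torch

def calculate_parameters_sketch(state_dict, prefix=""):
    return sum(w.nelement() for k, w in state_dict.items() if k.startswith(prefix))

def dtype_size(dtype):
    return torch.tensor([], dtype=dtype).element_size()

clip_sd = {"text_model.embeddings.token_embedding.weight": torch.zeros(49408, 768, dtype=torch.float16)}
parameters = calculate_parameters_sketch(clip_sd)
estimated_bytes = parameters * dtype_size(torch.float16)
print(parameters, estimated_bytes)  # the estimate the initial-device heuristic compares against free memory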
@@ -94,7 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
|
||||
self.operations = comfy.ops.manual_cast
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
self.max_length = max_length
|
||||
@@ -140,15 +141,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
|
||||
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||
embedding_weights = []
|
||||
|
||||
for x in tokens:
|
||||
tokens_temp = []
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
if y == token_dict_size: #EOS token
|
||||
y = -1
|
||||
tokens_temp += [int(y)]
|
||||
else:
|
||||
if y.shape[0] == current_embeds.weight.shape[1]:
|
||||
@@ -163,12 +162,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
n = token_dict_size
|
||||
if len(embedding_weights) > 0:
|
||||
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
||||
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
|
||||
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
||||
new_embedding.weight[:token_dict_size] = current_embeds.weight
|
||||
for x in embedding_weights:
|
||||
new_embedding.weight[n] = x
|
||||
n += 1
|
||||
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
|
||||
self.transformer.set_input_embeddings(new_embedding)
|
||||
|
||||
processed_tokens = []
|
||||
@@ -197,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
|
||||
if self.layer == "last":
|
||||
@@ -315,6 +313,17 @@ def expand_directory_list(directories):
|
||||
dirs.add(root)
|
||||
return list(dirs)
|
||||
|
||||
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
|
||||
i = 0
|
||||
out_list = []
|
||||
for k in embed:
|
||||
if k.startswith(prefix) and k.endswith(suffix):
|
||||
out_list.append(embed[k])
|
||||
if len(out_list) == 0:
|
||||
return None
|
||||
|
||||
return torch.cat(out_list, dim=0)
|
||||
|
||||
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
@@ -381,12 +390,16 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
elif embed_key is not None and embed_key in embed:
|
||||
embed_out = embed[embed_key]
|
||||
else:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
|
||||
if embed_out is None:
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
|
||||
if embed_out is None:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
@@ -519,12 +532,14 @@ class SDTokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
@@ -534,6 +549,8 @@ class SD1Tokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
||||
|
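
The bundled_embed() helper above lets a LoRA-format file carry textual-inversion embeddings: every key matching the bundle_emb prefix/suffix pair is collected and concatenated into one (num_vectors, embedding_size) tensor. A standalone sketch with dummy keys and tensors:

import torch

def bundled_embed_sketch(embed, prefix, suffix):
    out_list = []
    for k in embed:
        if k.startswith(prefix) and k.endswith(suffix):
            out_list.append(embed[k])
    if len(out_list) == 0:
        return None
    return torch.cat(out_list, dim=0)

embed = {
    "bundle_emb.my_embedding.string_to_param.*": torch.randn(4, 768),
    "lora_unet_down_blocks_0.alpha": torch.tensor(8.0),
}
out = bundled_embed_sketch(embed, 'bundle_emb.', '.string_to_param.*')
print(out.shape)  # torch.Size([4, 768]); non-matching keys are ignored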
@@ -6,7 +6,7 @@
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "eos_token_id": 49407,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
@@ -16,12 +16,12 @@ class SDXLClipG(sd1_clip.SDClipModel):
        return super().load_sd(sd)

class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, tokenizer_path=None, embedding_directory=None):
    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')


class SDXLTokenizer:
    def __init__(self, embedding_directory=None):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

@@ -34,6 +34,9 @@ class SDXLTokenizer:
    def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        return {}

class SDXLClipModel(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None):
        super().__init__()
@@ -68,12 +71,12 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):


class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, tokenizer_path=None, embedding_directory=None):
    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')

class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)

class StableCascadeClipG(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
@@ -3,11 +3,13 @@ from . import model_base
|
||||
from . import utils
|
||||
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
import comfy.text_encoders.sd2_clip
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -29,6 +31,7 @@ class SD15(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
k = list(state_dict.keys())
|
||||
@@ -75,6 +78,7 @@ class SD20(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||
@@ -100,7 +104,7 @@ class SD20(supported_models_base.BASE):
|
||||
return state_dict
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
|
||||
|
||||
class SD21UnclipL(SD20):
|
||||
unet_config = {
|
||||
@@ -138,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXLRefiner(self, device=device)
|
||||
@@ -176,6 +181,8 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
memory_usage_factor = 0.7
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
||||
@@ -503,6 +510,9 @@ class SD3(supported_models_base.BASE):
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
@@ -524,7 +534,7 @@ class SD3(supported_models_base.BASE):
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -555,7 +565,7 @@ class StableAudio(supported_models_base.BASE):
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)
|
||||
|
||||
class AuraFlow(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -580,6 +590,90 @@ class AuraFlow(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
|
||||
class HunyuanDiT(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hydit",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.018,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanDiT(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
|
||||
|
||||
class HunyuanDiT1(HunyuanDiT):
|
||||
unet_config = {
|
||||
"image_model": "hydit1",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start" : 0.00085,
|
||||
"linear_end" : 0.03,
|
||||
}
|
||||
|
||||
class Flux(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
"guidance_embed": True,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
memory_usage_factor = 2.8
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
||||
|
||||
class FluxSchnell(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
"guidance_embed": False,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
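
The Flux entry above decides the T5 dtype by probing the checkpoint for a known t5xxl key and reusing its storage dtype, the same trick SD3 already used. A standalone sketch of the probe (the state dict below is a dummy):

import torch

def detect_t5_dtype(state_dict, prefix="text_encoders."):
    t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(prefix)
    if t5_key in state_dict:
        return state_dict[t5_key].dtype
    return None  # no bundled T5 weights: fall back to the text encoder default

sd = {"text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight": torch.zeros(4096, dtype=torch.bfloat16)}
print(detect_t5_dtype(sd))  # torch.bfloat16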
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from . import model_base
|
||||
from . import utils
|
||||
@@ -27,7 +45,10 @@ class BASE:
|
||||
text_encoder_key_prefix = ["cond_stage_model."]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
|
@@ -1,21 +1,21 @@
|
||||
from comfy import sd1_clip
|
||||
from .llama_tokenizer import LLAMATokenizer
|
||||
import comfy.t5
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class PT5XlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
|
||||
|
||||
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
|
||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
|
comfy/text_encoders/bert.py (new file)
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
|
||||
class BertAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
self.heads = heads
|
||||
self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def forward(self, x, mask=None, optimized_attention=None):
|
||||
q = self.query(x)
|
||||
k = self.key(x)
|
||||
v = self.value(x)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads, mask)
|
||||
return out
|
||||
|
||||
class BertOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||
self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(0.0)
|
||||
|
||||
def forward(self, x, y):
|
||||
x = self.dense(x)
|
||||
# hidden_states = self.dropout(hidden_states)
|
||||
x = self.LayerNorm(x + y)
|
||||
return x
|
||||
|
||||
class BertAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.self = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
y = self.self(x, mask, optimized_attention)
|
||||
return self.output(y, x)
|
||||
|
||||
class BertIntermediate(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dense(x)
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class BertBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
|
||||
self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
|
||||
self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
x = self.attention(x, mask, optimized_attention)
|
||||
y = self.intermediate(x)
|
||||
return self.output(y, x)
|
||||
|
||||
class BertEncoder(torch.nn.Module):
|
||||
def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
class BertEmbeddings(torch.nn.Module):
|
||||
def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
|
||||
self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
|
||||
self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, token_type_ids=None, dtype=None):
|
||||
x = self.word_embeddings(input_tokens, out_dtype=dtype)
|
||||
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
|
||||
if token_type_ids is not None:
|
||||
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
|
||||
else:
|
||||
x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
|
||||
x = self.LayerNorm(x)
|
||||
return x
|
||||
|
||||
|
||||
class BertModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
|
||||
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
x, i = self.encoder(x, mask, intermediate_output)
|
||||
return x, i
|
||||
|
||||
|
||||
class BertModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.bert = BertModel_(config_dict, dtype, device, operations)
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.bert.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.bert.embeddings.word_embeddings = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.bert(*args, **kwargs)
|
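
The only non-obvious part of the new BertModel_ forward is the additive attention mask: a (batch, seq_len) 0/1 padding mask is expanded to (batch, 1, seq_len, seq_len) and padded key positions are filled with -inf so they vanish after softmax. The same two lines, isolated so they can be run and inspected:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # one sequence with two padding tokens
x_dtype = torch.float32

mask = 1.0 - attention_mask.to(x_dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

print(mask.shape)     # torch.Size([1, 1, 5, 5])
print(mask[0, 0, 0])  # tensor([0., 0., 0., -inf, -inf])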
comfy/text_encoders/flux.py (new file)
@@ -0,0 +1,71 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.model_management
|
||||
from transformers import T5TokenizerFast
|
||||
import torch
|
||||
import os
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_l.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.clip_l.set_clip_options(options)
|
||||
self.t5xxl.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.clip_l.reset_clip_options()
|
||||
self.t5xxl.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
|
||||
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return t5_out, l_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
return FluxClipModel_
|
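
flux_clip() above is a small factory: the detected T5 dtype is baked into a subclass so the loader can keep instantiating every text encoder with the same (device, dtype) signature. A generic sketch of that closure pattern, with a dummy module standing in for FluxClipModel:

import torch

class DummyClipModel(torch.nn.Module):
    def __init__(self, dtype_t5=None, device="cpu", dtype=None):
        super().__init__()
        self.dtype_t5 = dtype_t5

def make_clip_class(dtype_t5=None):
    class DummyClipModel_(DummyClipModel):
        def __init__(self, device="cpu", dtype=None):
            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
    return DummyClipModel_

clip_class = make_clip_class(dtype_t5=torch.bfloat16)
model = clip_class(device="cpu", dtype=torch.float16)
print(model.dtype_t5)  # torch.bfloat16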
comfy/text_encoders/hydit.py (new file)
@@ -0,0 +1,79 @@
|
||||
from comfy import sd1_clip
|
||||
from transformers import BertTokenizer
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
from .bert import BertModel
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
import torch
|
||||
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
|
||||
|
||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
|
||||
|
||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class HyditTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
|
||||
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
|
||||
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
|
||||
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.hydit_clip.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||
|
||||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
self.dtypes.add(dtype)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
|
||||
mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
|
||||
return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "bert.encoder.layer.0.attention.self.query.weight" in sd:
|
||||
return self.hydit_clip.load_sd(sd)
|
||||
else:
|
||||
return self.mt5xl.load_sd(sd)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.hydit_clip.set_clip_options(options)
|
||||
self.mt5xl.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.hydit_clip.reset_clip_options()
|
||||
self.mt5xl.reset_clip_options()
|
35
comfy/text_encoders/hydit_clip.json
Normal file
@@ -0,0 +1,35 @@
{
  "_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "directionality": "bidi",
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.22.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 47020
}
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
@@ -0,0 +1,16 @@
{
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "name_or_path": "hfl/chinese-roberta-wwm-ext",
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]",
  "model_max_length": 77
}
47020
comfy/text_encoders/hydit_clip_tokenizer/vocab.txt
Normal file
File diff suppressed because it is too large
@@ -1,22 +0,0 @@
import os

class LLAMATokenizer:
    @staticmethod
    def from_pretrained(path):
        return LLAMATokenizer(path)

    def __init__(self, tokenizer_path):
        import sentencepiece
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
        self.end = self.tokenizer.eos_id()

    def get_vocab(self):
        out = {}
        for i in range(self.tokenizer.get_piece_size()):
            out[self.tokenizer.id_to_piece(i)] = i
        return out

    def __call__(self, string):
        out = self.tokenizer.encode(string)
        out += [self.end]
        return {"input_ids": out}
22
comfy/text_encoders/mt5_config_xl.json
Normal file
@@ -0,0 +1,22 @@
{
  "d_ff": 5120,
  "d_kv": 64,
  "d_model": 2048,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "dense_act_fn": "gelu_pytorch_tanh",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 24,
  "num_heads": 32,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "vocab_size": 250112
}
@@ -1,21 +1,21 @@
|
||||
from comfy import sd1_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||
|
||||
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
@@ -11,12 +11,12 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
@@ -5,7 +5,7 @@
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "eos_token_id": 49407,
  "hidden_act": "gelu",
  "hidden_size": 1024,
  "initializer_factor": 1.0,
@@ -1,7 +1,7 @@
|
||||
from comfy import sd1_clip
|
||||
from comfy import sdxl_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import comfy.text_encoders.t5
|
||||
import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
@@ -10,25 +10,16 @@ import logging
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
class SDT5XXLModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
|
||||
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
@@ -43,6 +34,9 @@ class SD3Tokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
@@ -60,14 +54,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
self.clip_g = None
|
||||
|
||||
if t5:
|
||||
if dtype_t5 is None:
|
||||
dtype_t5 = dtype
|
||||
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
|
||||
dtype_t5 = dtype
|
||||
|
||||
if not comfy.model_management.supports_cast(device, dtype_t5):
|
||||
dtype_t5 = dtype
|
||||
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.dtypes.add(dtype_t5)
|
||||
else:
|
||||
@@ -94,7 +81,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
||||
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
|
||||
lg_out = None
|
||||
pooled = None
|
||||
out = None
|
||||
@@ -121,7 +108,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
if lg_out is not None:
|
||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
else:
|
32
comfy/text_encoders/spiece_tokenizer.py
Normal file
@@ -0,0 +1,32 @@
import os
import torch

class SPieceTokenizer:
    add_eos = True

    @staticmethod
    def from_pretrained(path):
        return SPieceTokenizer(path)

    def __init__(self, tokenizer_path):
        import sentencepiece
        if torch.is_tensor(tokenizer_path):
            tokenizer_path = tokenizer_path.numpy().tobytes()

        if isinstance(tokenizer_path, bytes):
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
        else:
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)

    def get_vocab(self):
        out = {}
        for i in range(self.tokenizer.get_piece_size()):
            out[self.tokenizer.id_to_piece(i)] = i
        return out

    def __call__(self, string):
        out = self.tokenizer.encode(string)
        return {"input_ids": out}

    def serialize_model(self):
        return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
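Because SPieceTokenizer accepts either a file path, raw bytes, or a serialized ByteTensor, the tokenizer state can ride along in a checkpoint state dict and be rebuilt later. A rough round-trip sketch under that assumption (the spiece.model path is illustrative, not shipped with the repo):

import torch
import sentencepiece

# Serialize: load from file, then store the model proto as a ByteTensor (as serialize_model() does).
sp = sentencepiece.SentencePieceProcessor(model_file="spiece.model", add_eos=True)
proto_tensor = torch.ByteTensor(list(sp.serialized_model_proto()))

# Deserialize: ByteTensor -> bytes -> SentencePieceProcessor (the bytes branch of __init__ above).
restored = sentencepiece.SentencePieceProcessor(model_proto=proto_tensor.numpy().tobytes(), add_eos=True)
assert restored.encode("hello world") == sp.encode("hello world")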
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
|
||||
@@ -11,7 +12,7 @@ class T5LayerNorm(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
return comfy.ops.cast_to_input(self.weight, x) * x
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
@@ -82,7 +83,7 @@ class T5Attention(torch.nn.Module):
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
|
||||
self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
@@ -132,7 +133,7 @@ class T5Attention(torch.nn.Module):
|
||||
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
def compute_bias(self, query_length, key_length, device, dtype):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||
@@ -143,7 +144,7 @@ class T5Attention(torch.nn.Module):
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
@@ -152,7 +153,7 @@ class T5Attention(torch.nn.Module):
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)
|
||||
|
||||
if past_bias is not None:
|
||||
if mask is not None:
|
||||
@@ -199,7 +200,7 @@ class T5Stack(torch.nn.Module):
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
@@ -223,9 +224,9 @@ class T5(torch.nn.Module):
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
|
||||
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
@@ -234,5 +235,7 @@ class T5(torch.nn.Module):
|
||||
self.shared = embeddings
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
x = self.shared(input_ids)
|
||||
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
||||
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
x = torch.nan_to_num(x) #Fix for fp8 T5 base
|
||||
return self.encoder(x, *args, **kwargs)
|
140
comfy/utils.py
@@ -1,3 +1,22 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
@@ -11,7 +30,7 @@ import itertools
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
else:
|
||||
if safe_load:
|
||||
@@ -40,9 +59,22 @@ def calculate_parameters(sd, prefix=""):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
params += sd[k].nelement()
|
||||
w = sd[k]
|
||||
params += w.nelement()
|
||||
return params
|
||||
|
||||
def weight_dtype(sd, prefix=""):
|
||||
dtypes = {}
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
||||
|
||||
if len(dtypes) == 0:
|
||||
return None
|
||||
|
||||
return max(dtypes, key=dtypes.get)
|
||||
|
||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||
for x in keys_to_replace:
|
||||
if x in state_dict:
|
||||
@@ -402,6 +434,110 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_double_layers = mmdit_config.get("depth", 0)
|
||||
n_single_layers = mmdit_config.get("depth_single_blocks", 0)
|
||||
hidden_size = mmdit_config.get("hidden_size", 0)
|
||||
|
||||
key_map = {}
|
||||
for index in range(n_double_layers):
|
||||
prefix_from = "transformer_blocks.{}".format(index)
|
||||
prefix_to = "{}double_blocks.{}".format(output_prefix, index)
|
||||
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {
|
||||
"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||
"norm1.linear.weight": "img_mod.lin.weight",
|
||||
"norm1.linear.bias": "img_mod.lin.bias",
|
||||
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
||||
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
||||
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
||||
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
||||
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
||||
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
||||
"ff.net.2.weight": "img_mlp.2.weight",
|
||||
"ff.net.2.bias": "img_mlp.2.bias",
|
||||
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
||||
|
||||
for index in range(n_single_layers):
|
||||
prefix_from = "single_transformer_blocks.{}".format(index)
|
||||
prefix_to = "{}single_blocks.{}".format(output_prefix, index)
|
||||
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.linear1.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
|
||||
|
||||
block_map = {
|
||||
"norm.linear.weight": "modulation.lin.weight",
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
||||
|
||||
MAP_BASIC = {
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
("img_in.bias", "x_embedder.bias"),
|
||||
("img_in.weight", "x_embedder.weight"),
|
||||
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
||||
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
||||
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
||||
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
||||
("txt_in.bias", "context_embedder.bias"),
|
||||
("txt_in.weight", "context_embedder.weight"),
|
||||
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
||||
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
||||
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
|
@@ -7,6 +7,7 @@ import io
|
||||
import json
|
||||
import struct
|
||||
import random
|
||||
import hashlib
|
||||
from comfy.cli_args import args
|
||||
|
||||
class EmptyLatentAudio:
|
||||
@@ -147,7 +148,7 @@ class SaveAudio:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"]):
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.flac"
|
||||
|
||||
|
27
comfy_extras/nodes_controlnet.py
Normal file
@@ -0,0 +1,27 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

class SetUnionControlNetType:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"control_net": ("CONTROL_NET", ),
                             "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
                             }}

    CATEGORY = "conditioning/controlnet"
    RETURN_TYPES = ("CONTROL_NET",)

    FUNCTION = "set_controlnet_type"

    def set_controlnet_type(self, control_net, type):
        control_net = control_net.copy()
        type_number = UNION_CONTROLNET_TYPES.get(type, -1)
        if type_number >= 0:
            control_net.set_extra_arg("control_type", [type_number])
        else:
            control_net.set_extra_arg("control_type", [])

        return (control_net,)

NODE_CLASS_MAPPINGS = {
    "SetUnionControlNetType": SetUnionControlNetType,
}
@@ -111,6 +111,25 @@ class SDTurboScheduler:
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
return (sigmas, )
|
||||
|
||||
class BetaSamplingScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
|
||||
"beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, steps, alpha, beta):
|
||||
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
|
||||
return (sigmas, )
|
||||
|
||||
class VPScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -276,6 +295,23 @@ class SamplerDPMPP_SDE:
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||
return (sampler, )
|
||||
|
||||
class SamplerDPMPP_2S_Ancestral:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, eta, s_noise):
|
||||
sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
|
||||
return (sampler, )
|
||||
|
||||
class SamplerEulerAncestral:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -638,6 +674,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ExponentialScheduler": ExponentialScheduler,
|
||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||
"VPScheduler": VPScheduler,
|
||||
"BetaSamplingScheduler": BetaSamplingScheduler,
|
||||
"SDTurboScheduler": SDTurboScheduler,
|
||||
"KSamplerSelect": KSamplerSelect,
|
||||
"SamplerEulerAncestral": SamplerEulerAncestral,
|
||||
@@ -646,6 +683,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
|
||||
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
@@ -662,4 +700,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++",
|
||||
}
|
||||
}
|
||||
|
47
comfy_extras/nodes_flux.py
Normal file
@@ -0,0 +1,47 @@
import node_helpers

class CLIPTextEncodeFlux:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning/flux"

    def encode(self, clip, clip_l, t5xxl, guidance):
        tokens = clip.tokenize(clip_l)
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]

        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
        cond = output.pop("cond")
        output["guidance"] = guidance
        return ([[cond, output]], )

class FluxGuidance:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "conditioning": ("CONDITIONING", ),
            "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
            }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"

    CATEGORY = "advanced/conditioning/flux"

    def append(self, conditioning, guidance):
        c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
        return (c, )


NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
    "FluxGuidance": FluxGuidance,
}
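Both nodes return CONDITIONING as a list of [cond, extras] pairs, with the guidance value stored next to the pooled output rather than baked into the tensor. A hedged sketch of what a downstream node roughly receives (shapes and the dummy tensors are only indicative; the real values come from the loaded text encoders):

import torch

# Approximate shape of one Flux conditioning entry after CLIPTextEncodeFlux.
cond = torch.zeros(1, 256, 4096)          # T5-XXL token features
extras = {"pooled_output": torch.zeros(1, 768), "guidance": 3.5}
conditioning = [[cond, extras]]

# FluxGuidance effectively rewrites the "guidance" key on every entry.
new_conditioning = [[c, {**d, "guidance": 7.0}] for c, d in conditioning]
print(new_conditioning[0][1]["guidance"])  # 7.0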
@@ -34,7 +34,7 @@ class FreeU:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches"
    CATEGORY = "model_patches/unet"

    def patch(self, model, b1, b2, s1, s2):
        model_channels = model.model.model_config.unet_config["model_channels"]
@@ -73,7 +73,7 @@ class FreeU_V2:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches"
    CATEGORY = "model_patches/unet"

    def patch(self, model, b1, b2, s1, s2):
        model_channels = model.model.model_config.unet_config["model_channels"]
25
comfy_extras/nodes_hunyuan.py
Normal file
@@ -0,0 +1,25 @@
class CLIPTextEncodeHunyuanDiT:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning"

    def encode(self, clip, bert, mt5xl):
        tokens = clip.tokenize(bert)
        tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]

        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
        cond = output.pop("cond")
        return ([[cond, output]], )


NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}
@@ -32,7 +32,7 @@ class HyperTile:
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches"
    CATEGORY = "model_patches/unet"

    def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
        model_channels = model.model.model_config.unet_config["model_channels"]
@@ -2,6 +2,7 @@ import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.model_sampling
|
||||
import comfy.latent_formats
|
||||
import nodes
|
||||
import torch
|
||||
|
||||
class LCM(comfy.model_sampling.EPS):
|
||||
@@ -170,6 +171,42 @@ class ModelSamplingAuraFlow(ModelSamplingSD3):
|
||||
def patch_aura(self, model, shift):
|
||||
return self.patch(model, shift, multiplier=1.0)
|
||||
|
||||
class ModelSamplingFlux:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, max_shift, base_shift, width, height):
|
||||
m = model.clone()
|
||||
|
||||
x1 = 256
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingFlux
|
||||
sampling_type = comfy.model_sampling.CONST
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift=shift)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
|
||||
class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -284,5 +321,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingStableCascade": ModelSamplingStableCascade,
|
||||
"ModelSamplingSD3": ModelSamplingSD3,
|
||||
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
|
||||
"ModelSamplingFlux": ModelSamplingFlux,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
||||
|
@@ -264,6 +264,7 @@ class CLIPSave:

        metadata = {}
        if not args.disable_metadata:
            metadata["format"] = "pt"
            metadata["prompt"] = prompt_info
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
@@ -75,9 +75,36 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["img_in."] = argument
|
||||
arg_dict["time_in."] = argument
|
||||
arg_dict["guidance_in"] = argument
|
||||
arg_dict["vector_in."] = argument
|
||||
arg_dict["txt_in."] = argument
|
||||
|
||||
for i in range(19):
|
||||
arg_dict["double_blocks.{}.".format(i)] = argument
|
||||
|
||||
for i in range(38):
|
||||
arg_dict["single_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
"ModelMergeSDXL": ModelMergeSDXL,
|
||||
"ModelMergeSD3_2B": ModelMergeSD3_2B,
|
||||
"ModelMergeFlux1": ModelMergeFlux1,
|
||||
}
|
||||
|
@@ -12,14 +12,14 @@ class PerturbedAttentionGuidance:
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    CATEGORY = "model_patches/unet"

    def patch(self, model, scale):
        unet_block = "middle"
@@ -96,7 +96,7 @@ class SelfAttentionGuidance:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
                              "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
                              "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                              }}
    RETURN_TYPES = ("MODEL",)
@@ -27,8 +27,8 @@ class EmptySD3LatentImage:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
@@ -92,7 +92,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
CATEGORY = "_for_testing/sd3"
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripleCLIPLoader": TripleCLIPLoader,
|
||||
@@ -100,3 +100,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
||||
"ControlNetApplySD3": ControlNetApplySD3,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
||||
}
|
||||
|
17
execution.py
@@ -3,6 +3,7 @@ import copy
|
||||
import logging
|
||||
import threading
|
||||
import heapq
|
||||
import time
|
||||
import traceback
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
@@ -187,6 +188,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
}
|
||||
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
return (False, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
@@ -247,6 +253,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif class_type != old_prompt[unique_id]['class_type']:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@@ -281,7 +289,11 @@ class PromptExecutor:
|
||||
self.success = True
|
||||
self.old_prompt = {}
|
||||
|
||||
def add_message(self, event, data, broadcast: bool):
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
**data,
|
||||
"timestamp": int(time.time() * 1000),
|
||||
}
|
||||
self.status_messages.append((event, data))
|
||||
if self.server.client_id is not None or broadcast:
|
||||
self.server.send_sync(event, data, self.server.client_id)
|
||||
@@ -392,6 +404,9 @@ class PromptExecutor:
|
||||
if self.success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
|
24
fix_torch.py
Normal file
@@ -0,0 +1,24 @@
import importlib.util
import shutil
import os
import ctypes
import logging


torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
    lib_folder = os.path.join(folder, "lib")
    test_file = os.path.join(lib_folder, "fbgemm.dll")
    dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
    if os.path.exists(dest):
        break

    with open(test_file, 'rb') as f:
        contents = f.read()
        if b"libomp140.x86_64.dll" not in contents:
            break
    try:
        mydll = ctypes.cdll.LoadLibrary(test_file)
    except FileNotFoundError as e:
        logging.warning("Detected pytorch version with libomp issue, patching.")
        shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Set, List, Dict, Tuple
|
||||
from collections.abc import Collection
|
||||
|
||||
supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
|
||||
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
|
||||
SupportedFileExtensionsType = Set[str]
|
||||
ScanPathType = List[str]
|
||||
folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {}
|
||||
folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}
|
||||
|
||||
base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
models_dir = os.path.join(base_path, "models")
|
||||
@@ -42,7 +42,7 @@ temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp
|
||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
|
||||
|
||||
filename_list_cache = {}
|
||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
try:
|
||||
@@ -50,33 +50,33 @@ if not os.path.exists(input_directory):
|
||||
except:
|
||||
logging.error("Failed to create input directory")
|
||||
|
||||
def set_output_directory(output_dir):
|
||||
def set_output_directory(output_dir: str) -> None:
|
||||
global output_directory
|
||||
output_directory = output_dir
|
||||
|
||||
def set_temp_directory(temp_dir):
|
||||
def set_temp_directory(temp_dir: str) -> None:
|
||||
global temp_directory
|
||||
temp_directory = temp_dir
|
||||
|
||||
def set_input_directory(input_dir):
|
||||
def set_input_directory(input_dir: str) -> None:
|
||||
global input_directory
|
||||
input_directory = input_dir
|
||||
|
||||
def get_output_directory():
|
||||
def get_output_directory() -> str:
|
||||
global output_directory
|
||||
return output_directory
|
||||
|
||||
def get_temp_directory():
|
||||
def get_temp_directory() -> str:
|
||||
global temp_directory
|
||||
return temp_directory
|
||||
|
||||
def get_input_directory():
|
||||
def get_input_directory() -> str:
|
||||
global input_directory
|
||||
return input_directory
|
||||
|
||||
|
||||
#NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||
def get_directory_by_type(type_name):
|
||||
def get_directory_by_type(type_name: str) -> str | None:
|
||||
if type_name == "output":
|
||||
return get_output_directory()
|
||||
if type_name == "temp":
|
||||
@@ -88,7 +88,7 @@ def get_directory_by_type(type_name):
|
||||
|
||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||
# otherwise use default_path as base_dir
|
||||
def annotated_filepath(name):
|
||||
def annotated_filepath(name: str) -> tuple[str, str | None]:
|
||||
if name.endswith("[output]"):
|
||||
base_dir = get_output_directory()
|
||||
name = name[:-9]
|
||||
@@ -104,7 +104,7 @@ def annotated_filepath(name):
|
||||
return name, base_dir
|
||||
|
||||
|
||||
def get_annotated_filepath(name, default_dir=None):
|
||||
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
@@ -116,7 +116,7 @@ def get_annotated_filepath(name, default_dir=None):
|
||||
return os.path.join(base_dir, name)
|
||||
|
||||
|
||||
def exists_annotated_filepath(name):
|
||||
def exists_annotated_filepath(name) -> bool:
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
@@ -126,17 +126,17 @@ def exists_annotated_filepath(name):
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
def add_model_folder_path(folder_name, full_folder_path):
|
||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
||||
global folder_names_and_paths
|
||||
if folder_name in folder_names_and_paths:
|
||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||
else:
|
||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||
|
||||
def get_folder_paths(folder_name):
|
||||
def get_folder_paths(folder_name: str) -> list[str]:
|
||||
return folder_names_and_paths[folder_name][0][:]
|
||||
|
||||
def recursive_search(directory, excluded_dir_names=None):
|
||||
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}
|
||||
|
||||
@@ -153,6 +153,10 @@ def recursive_search(directory, excluded_dir_names=None):
|
||||
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
|
||||
|
||||
logging.debug("recursive file list on directory {}".format(directory))
|
||||
dirpath: str
|
||||
subdirs: list[str]
|
||||
filenames: list[str]
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
for file_name in filenames:
|
||||
@@ -160,7 +164,7 @@ def recursive_search(directory, excluded_dir_names=None):
|
||||
result.append(relative_path)
|
||||
|
||||
for d in subdirs:
|
||||
path = os.path.join(dirpath, d)
|
||||
path: str = os.path.join(dirpath, d)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
@@ -169,12 +173,12 @@ def recursive_search(directory, excluded_dir_names=None):
|
||||
logging.debug("found {} files".format(len(result)))
|
||||
return result, dirs
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
|
||||
|
||||
|
||||
|
||||
def get_full_path(folder_name, filename):
|
||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
global folder_names_and_paths
|
||||
if folder_name not in folder_names_and_paths:
|
||||
return None
|
||||
@@ -189,7 +193,7 @@ def get_full_path(folder_name, filename):
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list_(folder_name):
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
global folder_names_and_paths
|
||||
output_list = set()
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
@@ -199,9 +203,9 @@ def get_filename_list_(folder_name):
|
||||
output_list.update(filter_files_extensions(files, folders[1]))
|
||||
output_folders = {**output_folders, **folders_all}
|
||||
|
||||
return (sorted(list(output_list)), output_folders, time.perf_counter())
|
||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||
|
||||
def cached_filename_list_(folder_name):
|
||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
if folder_name not in filename_list_cache:
|
||||
@@ -222,7 +226,7 @@ def cached_filename_list_(folder_name):
|
||||
|
||||
return out
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
def get_filename_list(folder_name: str) -> list[str]:
|
||||
out = cached_filename_list_(folder_name)
|
||||
if out is None:
|
||||
out = get_filename_list_(folder_name)
|
||||
@@ -230,17 +234,17 @@ def get_filename_list(folder_name):
|
||||
filename_list_cache[folder_name] = out
|
||||
return list(out[0])
|
||||
|
||||
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def map_filename(filename):
|
||||
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
|
||||
def map_filename(filename: str) -> tuple[int, str]:
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
return digits, prefix
|
||||
|
||||
def compute_vars(input, image_width, image_height):
|
||||
def compute_vars(input: str, image_width: int, image_height: int) -> str:
|
||||
input = input.replace("%width%", str(image_width))
|
||||
input = input.replace("%height%", str(image_height))
|
||||
return input
|
||||
|
6
main.py
@@ -74,6 +74,12 @@ if __name__ == "__main__":

    import cuda_malloc

    if args.windows_standalone_build:
        try:
            import fix_torch
        except:
            pass

    import comfy.utils
    import yaml
@@ -1,3 +1,7 @@
import hashlib

from comfy.cli_args import args

from PIL import ImageFile, UnidentifiedImageError

def conditioning_set_values(conditioning, values={}):
@@ -21,4 +25,13 @@ def pillow(fn, arg):
    finally:
        if prev_value is not None:
            ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
    return x
    return x

def hasher():
    hashfuncs = {
        "md5": hashlib.md5,
        "sha1": hashlib.sha1,
        "sha256": hashlib.sha256,
        "sha512": hashlib.sha512
    }
    return hashfuncs[args.default_hashing_function]
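The new hasher() helper just maps the default hashing function chosen via the CLI args onto a hashlib constructor, so call sites stay one-liners. A small usage sketch, assuming it runs from the ComfyUI root (the file path is illustrative):

import node_helpers

# hasher() returns e.g. hashlib.sha256; calling it yields a fresh hash object.
m = node_helpers.hasher()()
with open("input/example.png", "rb") as f:
    m.update(f.read())
print(m.hexdigest())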
53
nodes.py
@@ -748,7 +748,7 @@ class ControlNetApply:
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
if strength == 0:
|
||||
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
|
||||
if strength == 0:
|
||||
@@ -818,15 +818,22 @@ class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name):
|
||||
def load_unet(self, unet_name, weight_dtype):
|
||||
dtype = None
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
dtype = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
dtype = torch.float8_e5m2
|
||||
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path)
|
||||
model = comfy.sd.load_unet(unet_path, dtype=dtype)
|
||||
return (model,)
|
||||
|
||||
class CLIPLoader:
|
||||
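As a side note (not from the diff itself): the weight_dtype branching added to load_unet amounts to a small name-to-dtype lookup. A hypothetical standalone equivalent, assuming a PyTorch build that exposes the float8 dtypes used above:

import torch

WEIGHT_DTYPES = {
    "default": None,  # let the loader pick the dtype
    "fp8_e4m3fn": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

def resolve_weight_dtype(name):
    # Returns None for "default", mirroring the dtype = None fallback above.
    return WEIGHT_DTYPES.get(name)

print(resolve_weight_dtype("fp8_e4m3fn"))  # torch.float8_e4m3fn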
@@ -859,7 +866,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
                               "clip_name2": (folder_paths.get_filename_list("clip"), ),
-                              "type": (["sdxl", "sd3"], ),
+                              "type": (["sdxl", "sd3", "flux"], ),
                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -873,6 +880,8 @@ class DualCLIPLoader:
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
         elif type == "sd3":
             clip_type = comfy.sd.CLIPType.SD3
+        elif type == "flux":
+            clip_type = comfy.sd.CLIPType.FLUX

         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)
@@ -1843,6 +1852,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "StyleModelLoader": "Load Style Model",
     "CLIPVisionLoader": "Load CLIP Vision",
     "UpscaleModelLoader": "Load Upscale Model",
+    "UNETLoader": "Load Diffusion Model",
     # Conditioning
     "CLIPVisionEncode": "CLIP Vision Encode",
     "StyleModelApply": "Apply Style Model",
@@ -1890,29 +1900,29 @@ NODE_DISPLAY_NAME_MAPPINGS = {
 EXTENSION_WEB_DIRS = {}


-def get_relative_module_name(module_path: str) -> str:
+def get_module_name(module_path: str) -> str:
     """
     Returns the module name based on the given module path.
     Examples:
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "custom_nodes.my_custom_node"
-        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes.my
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.py") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__.py") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node/__init__/") -> "my_custom_node"
+        get_module_name("C:/Users/username/ComfyUI/custom_nodes/my_custom_node.disabled") -> "custom_nodes
     Args:
         module_path (str): The path of the module.
     Returns:
         str: The module name.
     """
-    relative_path = os.path.relpath(module_path, folder_paths.base_path)
+    base_path = os.path.basename(module_path)
     if os.path.isfile(module_path):
-        relative_path = os.path.splitext(relative_path)[0]
-    return relative_path.replace(os.sep, '.')
+        base_path = os.path.splitext(base_path)[0]
+    return base_path


-def load_custom_node(module_path: str, ignore=set()) -> bool:
+def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
     module_name = os.path.basename(module_path)
     if os.path.isfile(module_path):
         sp = os.path.splitext(module_path)
@@ -1939,7 +1949,7 @@ def load_custom_node(module_path: str, ignore=set()) -> bool:
         for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
             if name not in ignore:
                 NODE_CLASS_MAPPINGS[name] = node_cls
-                node_cls.RELATIVE_PYTHON_MODULE = get_relative_module_name(module_path)
+                node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
         if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
             NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
         return True
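Not part of the diff: a small sketch of the renamed helper and of how RELATIVE_PYTHON_MODULE is composed after this change (the path below is made up). Custom node packages keep the custom_nodes prefix via the new module_parent default, while builtin extras pass module_parent="comfy_extras".

import os

def get_module_name(module_path):
    # Same behaviour as the new helper above: last path component, extension stripped for files.
    base_path = os.path.basename(module_path)
    if os.path.isfile(module_path):
        base_path = os.path.splitext(base_path)[0]
    return base_path

module_path = "/workspace/ComfyUI/custom_nodes/my_custom_node"
print("{}.{}".format("custom_nodes", get_module_name(module_path)))  # custom_nodes.my_custom_node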
@@ -1974,7 +1984,7 @@ def init_external_custom_nodes():
             if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
             if module_path.endswith(".disabled"): continue
             time_before = time.perf_counter()
-            success = load_custom_node(module_path, base_node_names)
+            success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
             node_import_times.append((time.perf_counter() - time_before, module_path, success))

     if len(node_import_times) > 0:
@@ -2036,11 +2046,14 @@ def init_builtin_extra_nodes():
         "nodes_audio.py",
         "nodes_sd3.py",
         "nodes_gits.py",
+        "nodes_controlnet.py",
+        "nodes_hunyuan.py",
+        "nodes_flux.py",
     ]

     import_failed = []
     for node_file in extras_files:
-        if not load_custom_node(os.path.join(extras_dir, node_file)):
+        if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
             import_failed.append(node_file)

     return import_failed
pytest.ini

@@ -1,5 +1,8 @@
 [pytest]
 markers =
     inference: mark as inference test (deselect with '-m "not inference"')
-testpaths = tests
-addopts = -s
+testpaths =
+    tests
+    tests-unit
+addopts = -s
+pythonpath = .
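Usage note (not from the diff): with this configuration, a plain `pytest` run collects both `tests` and `tests-unit`, and the `inference` marker declared above can still be deselected with `pytest -m "not inference"`; `pythonpath = .` puts the repository root on `sys.path`, so the unit tests can import modules such as `app.frontend_management` directly.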
server.py (24 changes)

@@ -25,9 +25,11 @@ import mimetypes
 from comfy.cli_args import args
 import comfy.utils
 import comfy.model_management

+import node_helpers
+from app.frontend_management import FrontendManager
 from app.user_manager import UserManager


 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
     UNENCODED_PREVIEW_IMAGE = 2
@@ -83,8 +85,12 @@ class PromptServer():
         max_upload_size = round(args.max_upload_size * 1024 * 1024)
         self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
         self.sockets = dict()
-        self.web_root = os.path.join(os.path.dirname(
-            os.path.realpath(__file__)), "web")
+        self.web_root = (
+            FrontendManager.init_frontend(args.front_end_version)
+            if args.front_end_root is None
+            else args.front_end_root
+        )
+        logging.info(f"[Prompt Server] web root: {self.web_root}")
         routes = web.RouteTableDef()
         self.routes = routes
         self.last_node_id = None
@@ -121,7 +127,11 @@ class PromptServer():
         @routes.get("/")
         async def get_root(request):
-            return web.FileResponse(os.path.join(self.web_root, "index.html"))
+            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
+            response.headers['Cache-Control'] = 'no-cache'
+            response.headers["Pragma"] = "no-cache"
+            response.headers["Expires"] = "0"
+            return response

         @routes.get("/embeddings")
         def get_embeddings(self):
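Not part of the diff: a standalone aiohttp sketch of the same no-cache pattern, with a hypothetical local path and port. The intent of the added headers is that browsers always re-validate index.html, so a swapped-in frontend (for example one selected through the new front-end version setting) is picked up without a hard refresh.

from aiohttp import web

async def index(request):
    # Serve index.html but forbid caching, mirroring the get_root change above.
    response = web.FileResponse("./web/index.html")
    response.headers["Cache-Control"] = "no-cache"
    response.headers["Pragma"] = "no-cache"
    response.headers["Expires"] = "0"
    return response

app = web.Application()
app.router.add_get("/", index)
# web.run_app(app, port=8188)  # uncomment to serve locally; the port is an assumption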
@@ -156,10 +166,12 @@ class PromptServer():
             return type_dir, dir_type

         def compare_image_hash(filepath, image):
+            hasher = node_helpers.hasher()
+
             # function to compare hashes of two images to see if it already exists, fix to #3465
             if os.path.exists(filepath):
-                a = hashlib.sha256()
-                b = hashlib.sha256()
+                a = hasher()
+                b = hasher()
                 with open(filepath, "rb") as f:
                     a.update(f.read())
                     b.update(image.file.read())
tests-unit/README.md (new file, 8 lines)

@@ -0,0 +1,8 @@
+# Pytest Unit Tests
+
+## Install test dependencies
+
+`pip install -r tests-unit/requirements.txt`
+
+## Run tests
+`pytest tests-unit/`
tests-unit/app_test/__init__.py (new, empty file)
tests-unit/app_test/frontend_manager_test.py (new file, 100 lines)

@@ -0,0 +1,100 @@
+import argparse
+import pytest
+from requests.exceptions import HTTPError
+
+from app.frontend_management import (
+    FrontendManager,
+    FrontEndProvider,
+    Release,
+)
+from comfy.cli_args import DEFAULT_VERSION_STRING
+
+
+@pytest.fixture
+def mock_releases():
+    return [
+        Release(
+            id=1,
+            tag_name="1.0.0",
+            name="Release 1.0.0",
+            prerelease=False,
+            created_at="2022-01-01T00:00:00Z",
+            published_at="2022-01-01T00:00:00Z",
+            body="Release notes for 1.0.0",
+            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
+        ),
+        Release(
+            id=2,
+            tag_name="2.0.0",
+            name="Release 2.0.0",
+            prerelease=False,
+            created_at="2022-02-01T00:00:00Z",
+            published_at="2022-02-01T00:00:00Z",
+            body="Release notes for 2.0.0",
+            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_provider(mock_releases):
+    provider = FrontEndProvider(
+        owner="test-owner",
+        repo="test-repo",
+    )
+    provider.all_releases = mock_releases
+    provider.latest_release = mock_releases[1]
+    FrontendManager.PROVIDERS = [provider]
+    return provider
+
+
+def test_get_release(mock_provider, mock_releases):
+    version = "1.0.0"
+    release = mock_provider.get_release(version)
+    assert release == mock_releases[0]
+
+
+def test_get_release_latest(mock_provider, mock_releases):
+    version = "latest"
+    release = mock_provider.get_release(version)
+    assert release == mock_releases[1]
+
+
+def test_get_release_invalid_version(mock_provider):
+    version = "invalid"
+    with pytest.raises(ValueError):
+        mock_provider.get_release(version)
+
+
+def test_init_frontend_default():
+    version_string = DEFAULT_VERSION_STRING
+    frontend_path = FrontendManager.init_frontend(version_string)
+    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
+
+
+def test_init_frontend_invalid_version():
+    version_string = "test-owner/test-repo@1.100.99"
+    with pytest.raises(HTTPError):
+        FrontendManager.init_frontend_unsafe(version_string)
+
+
+def test_init_frontend_invalid_provider():
+    version_string = "invalid/invalid@latest"
+    with pytest.raises(HTTPError):
+        FrontendManager.init_frontend_unsafe(version_string)
+
+
+def test_parse_version_string():
+    version_string = "owner/repo@1.0.0"
+    repo_owner, repo_name, version = FrontendManager.parse_version_string(
+        version_string
+    )
+    assert repo_owner == "owner"
+    assert repo_name == "repo"
+    assert version == "1.0.0"
+
+
+def test_parse_version_string_invalid():
+    version_string = "invalid"
+    with pytest.raises(argparse.ArgumentTypeError):
+        FrontendManager.parse_version_string(version_string)
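Not part of the diff: judging from the tests above, the front-end version string follows an owner/repo@version pattern, with "latest" selecting the newest release and malformed strings raising argparse.ArgumentTypeError. A minimal usage sketch, assuming app.frontend_management is importable:

from app.frontend_management import FrontendManager

owner, repo, version = FrontendManager.parse_version_string("owner/repo@1.0.0")
print(owner, repo, version)  # owner repo 1.0.0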
tests-unit/requirements.txt (new file, 1 line)

@@ -0,0 +1 @@
+pytest>=7.8.0
Some files were not shown because too many files have changed in this diff.