mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Compare commits
127 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
f1c2301697 | ||
|
8d31a6632f | ||
|
b643eae08b | ||
|
baa6b4dc36 | ||
|
d4aeefc297 | ||
|
587e7ca654 | ||
|
c90459eba0 | ||
|
04278afb10 | ||
|
935ae153e1 | ||
|
e91662e784 | ||
|
63fafaef45 | ||
|
ec28cd9136 | ||
|
6eb5d64522 | ||
|
10a79e9898 | ||
|
ea3f39bd69 | ||
|
b33cd61070 | ||
|
34eda0f853 | ||
|
d31e226650 | ||
|
b79fd7d92c | ||
|
38c22e631a | ||
|
6bbdcd28ae | ||
|
ab130001a8 | ||
|
ca4b8f30e0 | ||
|
70b84058c1 | ||
|
2ca8f6e23d | ||
|
7985ff88b9 | ||
|
c6812947e9 | ||
|
9230f65823 | ||
|
6ab1e6fd4a | ||
|
07dcbc3a3e | ||
|
8ae23d8e80 | ||
|
7df42b9a23 | ||
|
5d8bbb7281 | ||
|
2c1d2375d6 | ||
|
64ccb3c7e3 | ||
|
9465b23432 | ||
|
bb4416dd5b | ||
|
c0b0da264b | ||
|
c26ca27207 | ||
|
7c6bb84016 | ||
|
c54d3ed5e6 | ||
|
c7ee4b37a1 | ||
|
7b70b266d8 | ||
|
8f60d093ba | ||
|
dafbe321d2 | ||
|
5f84ea63e8 | ||
|
843a7ff70c | ||
|
a60620dcea | ||
|
015f73dc49 | ||
|
904bf58e7d | ||
|
5f50263088 | ||
|
5e806f555d | ||
|
f07e5bb522 | ||
|
03ec517afb | ||
|
f257fc999f | ||
|
bb50e69839 | ||
|
510f3438c1 | ||
|
ea63b1c092 | ||
|
9953f22fce | ||
|
d1a6bd6845 | ||
|
83dbac28eb | ||
|
538cb068bc | ||
|
1b3eee672c | ||
|
5a69f84c3c | ||
|
9eee470244 | ||
|
045377ea89 | ||
|
4d341b78e8 | ||
|
6138f92084 | ||
|
be0726c1ed | ||
|
766ae119a8 | ||
|
fc90ceb6ba | ||
|
4506ddc86a | ||
|
20ace7c853 | ||
|
b29b3b86c5 | ||
|
22ec02afc0 | ||
|
39f114c44b | ||
|
6730f3e1a3 | ||
|
73332160c8 | ||
|
2622c55aff | ||
![]() |
1beb348ee2 | ||
|
9aa39e743c | ||
|
d31df04c8a | ||
|
e68763f40c | ||
|
310ad09258 | ||
|
4f7a3cb6fb | ||
|
bb222ceddb | ||
|
14af129c55 | ||
|
fca42836f2 | ||
|
858d51f91a | ||
|
cd5017c1c9 | ||
|
83f343146a | ||
|
b021cf67c7 | ||
|
1770fc77ed | ||
|
05a9f3faa1 | ||
|
86c5970ac0 | ||
|
bfc214f434 | ||
|
3f5939add6 | ||
|
5960f946a9 | ||
|
5cfe38f41c | ||
|
0f9c2a7822 | ||
|
153d0a8142 | ||
|
ab4dd19b91 | ||
|
f1d6cef71c | ||
|
33fb282d5c | ||
|
50bf66e5c4 | ||
|
e60e19b175 | ||
|
a5af64d3ce | ||
|
3e52e0364c | ||
|
34608de2e9 | ||
|
39fb74c5bd | ||
|
74e124f4d7 | ||
|
a562c17e8a | ||
|
5942c17d55 | ||
|
c032b11e07 | ||
|
b8ffb2937f | ||
|
ce37c11164 | ||
|
b5c3906b38 | ||
|
5d43e75e5b | ||
|
517f4a94e4 | ||
|
52a471c5c7 | ||
|
ad76574cb8 | ||
|
9acfe4df41 | ||
|
9829b013ea | ||
|
5c69cde037 | ||
|
e9589d6d92 | ||
|
0d82a798a5 | ||
|
925fff26fd |
@@ -75,6 +75,25 @@ else:
|
||||
print("pulling latest changes")
|
||||
pull(repo)
|
||||
|
||||
if "--stable" in sys.argv:
|
||||
def latest_tag(repo):
|
||||
versions = []
|
||||
for k in repo.references:
|
||||
try:
|
||||
prefix = "refs/tags/v"
|
||||
if k.startswith(prefix):
|
||||
version = list(map(int, k[len(prefix):].split(".")))
|
||||
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
|
||||
except:
|
||||
pass
|
||||
versions.sort()
|
||||
if len(versions) > 0:
|
||||
return versions[-1][1]
|
||||
return None
|
||||
latest_tag = latest_tag(repo)
|
||||
if latest_tag is not None:
|
||||
repo.checkout(latest_tag)
|
||||
|
||||
print("Done!")
|
||||
|
||||
self_update = True
|
||||
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
|
||||
shutil.copy(repo_req_path, req_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
|
||||
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
|
||||
|
||||
try:
|
||||
if not file_size(stable_update_script_to) > 10:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
8
.ci/update_windows/update_comfyui_stable.bat
Executable file
@@ -0,0 +1,8 @@
|
||||
@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
|
||||
if exist update_new.py (
|
||||
move /y update_new.py update.py
|
||||
echo Running updater again since it got updated.
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
|
||||
)
|
||||
if "%~1"=="" pause
|
@@ -14,7 +14,7 @@ run_cpu.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
|
||||
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
|
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
||||
pause
|
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
/web/assets/** linguist-generated
|
||||
/web/** linguist-vendored
|
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: ComfyUI Frontend Issues
|
||||
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
||||
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
||||
- name: ComfyUI Matrix Space
|
||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||
|
16
.github/workflows/pullrequest-ci-run.yml
vendored
16
.github/workflows/pullrequest-ci-run.yml
vendored
@@ -35,3 +35,19 @@ jobs:
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
use_prior_commit: 'true'
|
||||
comment:
|
||||
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
|
||||
})
|
||||
|
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: 'Close stale issues'
|
||||
on:
|
||||
schedule:
|
||||
# Run daily at 430 am PT
|
||||
- cron: '30 11 * * *'
|
||||
permissions:
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||
days-before-stale: 30
|
||||
days-before-close: 7
|
||||
stale-issue-label: 'Stale'
|
||||
only-labels: 'User Support'
|
||||
exempt-all-assignees: true
|
||||
exempt-all-milestones: true
|
@@ -1,10 +1,4 @@
|
||||
# This is a temporary action during frontend TS migration.
|
||||
# This file should be removed after TS migration is completed.
|
||||
# The browser test is here to ensure TS repo is working the same way as the
|
||||
# current JS code.
|
||||
# If you are adding UI feature, please sync your changes to the TS repo:
|
||||
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
|
||||
name: Playwright Browser Tests CI
|
||||
name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -21,15 +15,6 @@ jobs:
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- name: Checkout ComfyUI_frontend
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "huchenlei/ComfyUI_frontend"
|
||||
path: "ComfyUI_frontend"
|
||||
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: lts/*
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
@@ -45,16 +30,6 @@ jobs:
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
working-directory: ComfyUI
|
||||
- name: Install ComfyUI_frontend dependencies
|
||||
run: |
|
||||
npm ci
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Install Playwright Browsers
|
||||
run: npx playwright install --with-deps
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Run Playwright tests
|
||||
run: npx playwright test
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
@@ -62,12 +37,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
working-directory: ComfyUI
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: ComfyUI_frontend/playwright-report/
|
||||
retention-days: 30
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
30
.github/workflows/test-ui.yaml
vendored
30
.github/workflows/test-ui.yaml
vendored
@@ -1,30 +0,0 @@
|
||||
name: Tests CI
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 18
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
- name: Run Tests
|
||||
run: |
|
||||
npm ci
|
||||
npm run test:generate
|
||||
npm test -- --verbose
|
||||
working-directory: ./tests-ui
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
@@ -67,6 +67,7 @@ jobs:
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -18,4 +18,5 @@ venv/
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
*.log
|
||||
web_custom_versions/
|
||||
web_custom_versions/
|
||||
.DS_Store
|
||||
|
89
README.md
89
README.md
@@ -1,8 +1,35 @@
|
||||
ComfyUI
|
||||
=======
|
||||
The most powerful and modular stable diffusion GUI and backend.
|
||||
-----------
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular diffusion model GUI and backend.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||
[![Matrix][matrix-shield]][matrix-url]
|
||||
<br>
|
||||
[![][github-release-shield]][github-release-link]
|
||||
[![][github-release-date-shield]][github-release-link]
|
||||
[![][github-downloads-shield]][github-downloads-link]
|
||||
[![][github-downloads-latest-shield]][github-downloads-link]
|
||||
|
||||
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
|
||||
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
|
||||
[website-url]: https://www.comfy.org/
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
|
||||
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
</div>
|
||||
|
||||
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
|
||||
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
@@ -48,6 +75,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Alt + Enter | Cancel current generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
@@ -70,6 +98,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| Shift + Drag | Move multiple wires at once |
|
||||
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
|
||||
@@ -105,17 +135,17 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
|
||||
@@ -200,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
||||
|
||||
Use ```--preview-method auto``` to enable previews.
|
||||
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
|
||||
|
||||
## How to use TLS/SSL?
|
||||
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||
@@ -216,6 +246,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
|
||||
|
||||
See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
### Reporting Issues and Requesting Features
|
||||
|
||||
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
||||
|
||||
### Using the Latest Frontend
|
||||
|
||||
The new frontend is now the default for ComfyUI. However, please note:
|
||||
|
||||
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||
2. Daily releases are available in the separate frontend repository.
|
||||
|
||||
To use the most up-to-date frontend version:
|
||||
|
||||
1. For the latest daily release, launch ComfyUI with this command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
||||
```
|
||||
|
||||
2. For a specific version, replace `latest` with the desired version number:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||
```
|
||||
|
||||
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
# QA
|
||||
|
||||
### Which GPU should I buy for this?
|
||||
|
0
api_server/__init__.py
Normal file
0
api_server/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
3
api_server/routes/internal/README.md
Normal file
3
api_server/routes/internal/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# ComfyUI Internal Routes
|
||||
|
||||
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
0
api_server/routes/internal/__init__.py
Normal file
0
api_server/routes/internal/__init__.py
Normal file
44
api_server/routes/internal/internal_routes.py
Normal file
44
api_server/routes/internal/internal_routes.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory
|
||||
from api_server.services.file_service import FileService
|
||||
import app.logger
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
The top level web router for internal routes: /internal/*
|
||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||
Check README.md for more information.
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
async def list_files(request):
|
||||
directory_key = request.query.get('directory', '')
|
||||
try:
|
||||
file_list = self.file_service.list_files(directory_key)
|
||||
return web.json_response({"files": file_list})
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
@self.routes.get('/logs')
|
||||
async def get_logs(request):
|
||||
return web.json_response(app.logger.get_logs())
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
self.setup_routes()
|
||||
self._app.add_routes(self.routes)
|
||||
return self._app
|
0
api_server/services/__init__.py
Normal file
0
api_server/services/__init__.py
Normal file
13
api_server/services/file_service.py
Normal file
13
api_server/services/file_service.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import Dict, List, Optional
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||
|
||||
class FileService:
|
||||
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||
|
||||
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
42
api_server/utils/file_operations.py
Normal file
42
api_server/utils/file_operations.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List, Union, TypedDict, Literal
|
||||
from typing_extensions import TypeGuard
|
||||
class FileInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["file"]
|
||||
size: int
|
||||
|
||||
class DirectoryInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["directory"]
|
||||
|
||||
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
||||
|
||||
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
||||
return item["type"] == "file"
|
||||
|
||||
class FileSystemOperations:
|
||||
@staticmethod
|
||||
def walk_directory(directory: str) -> List[FileSystemItem]:
|
||||
file_list: List[FileSystemItem] = []
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for name in files:
|
||||
file_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(file_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "file",
|
||||
"size": os.path.getsize(file_path)
|
||||
})
|
||||
for name in dirs:
|
||||
dir_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(dir_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "directory"
|
||||
})
|
||||
return file_list
|
@@ -8,7 +8,7 @@ import zipfile
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
@@ -132,12 +132,13 @@ class FrontendManager:
|
||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||
|
||||
@classmethod
|
||||
def init_frontend_unsafe(cls, version_string: str) -> str:
|
||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
@@ -150,7 +151,7 @@ class FrontendManager:
|
||||
return cls.DEFAULT_FRONTEND_PATH
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
provider = FrontEndProvider(repo_owner, repo_name)
|
||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||
release = provider.get_release(version)
|
||||
|
||||
semantic_version = release["tag_name"].lstrip("v")
|
||||
@@ -158,15 +159,21 @@ class FrontendManager:
|
||||
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
|
||||
)
|
||||
if not os.path.exists(web_root):
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
try:
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
finally:
|
||||
# Clean up the directory if it is empty, i.e. the download failed
|
||||
if not os.listdir(web_root):
|
||||
os.rmdir(web_root)
|
||||
|
||||
return web_root
|
||||
|
||||
@classmethod
|
||||
|
31
app/logger.py
Normal file
31
app/logger.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
from logging.handlers import MemoryHandler
|
||||
from collections import deque
|
||||
|
||||
logs = None
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
|
||||
def get_logs():
|
||||
return "\n".join([formatter.format(x) for x in logs])
|
||||
|
||||
|
||||
def setup_logger(verbose: bool = False, capacity: int = 300):
|
||||
global logs
|
||||
if logs:
|
||||
return
|
||||
|
||||
# Setup default global logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
# Create a memory handler with a deque as its buffer
|
||||
logs = deque(maxlen=capacity)
|
||||
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
|
||||
memory_handler.buffer = logs
|
||||
memory_handler.setFormatter(formatter)
|
||||
logger.addHandler(memory_handler)
|
@@ -92,6 +92,10 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
@@ -112,10 +116,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
|
||||
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@@ -171,10 +179,3 @@ if args.windows_standalone_build:
|
||||
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
import logging
|
||||
logging_level = logging.INFO
|
||||
if args.verbose:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
|
@@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
num_positions = config_dict["max_position_embeddings"]
|
||||
self.eos_token_id = config_dict["eos_token_id"]
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
@@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
@@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
@@ -146,7 +148,7 @@ class ControlBase:
|
||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||
x *= (self.strength ** float(len(control_output) - i))
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
if output_dtype is not None and x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
@@ -204,7 +206,6 @@ class ControlNet(ControlBase):
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
@@ -234,7 +235,7 @@ class ControlNet(ControlBase):
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
@@ -389,7 +390,8 @@ def controlnet_config(sd):
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
offload_device = comfy.model_management.unet_offload_device()
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
@@ -403,12 +405,12 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
@@ -416,10 +418,11 @@ def load_controlnet_mmdit(sd):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
|
||||
latent_format = comfy.latent_formats.SDXL()
|
||||
@@ -427,6 +430,33 @@ def load_controlnet_hunyuandit(controlnet_data):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_xlabs(sd):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_instantx(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
num_union_modes = 0
|
||||
union_cnet = "controlnet_mode_embedder.weight"
|
||||
if union_cnet in new_sd:
|
||||
num_union_modes = new_sd[union_cnet].shape[0]
|
||||
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Flux()
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
@@ -489,7 +519,12 @@ def load_controlnet(ckpt_path, model=None):
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs(controlnet_data)
|
||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@@ -521,6 +556,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
|
@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
||||
if text_encoder2_path is not None:
|
||||
text_encoder_paths.append(text_encoder2_path)
|
||||
|
||||
unet = comfy.sd.load_unet(unet_path)
|
||||
unet = comfy.sd.load_diffusion_model(unet_path)
|
||||
|
||||
clip = None
|
||||
if output_clip:
|
||||
|
62
comfy/float.py
Normal file
62
comfy/float.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
|
||||
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||
elif dtype == torch.float8_e5m2:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
x = x.half()
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
sign = torch.where(abs_x == 0, 0, sign)
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||
|
||||
sign *= torch.where(
|
||||
normal_mask,
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
del abs_x
|
||||
|
||||
return sign.to(dtype=dtype)
|
||||
|
||||
|
||||
|
||||
def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float32:
|
||||
return value.to(dtype=torch.float32)
|
||||
if dtype == torch.float16:
|
||||
return value.to(dtype=torch.float16)
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
||||
|
||||
return value.to(dtype=dtype)
|
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
from . import utils
|
||||
from . import deis
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
@@ -509,6 +510,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
|
||||
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@@ -541,6 +545,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
|
||||
# logged_x = x.unsqueeze(0)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
if sigmas[i] == 1.0:
|
||||
sigma_s = 0.9999
|
||||
else:
|
||||
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
|
||||
r = 1 / 2
|
||||
h = t_down - t_i
|
||||
s = t_i + r * h
|
||||
sigma_s = sigma_fn(s)
|
||||
# sigma_s = sigmas[i+1]
|
||||
sigma_s_i_ratio = sigma_s / sigmas[i]
|
||||
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
|
||||
D_i = model(u, sigma_s * s_in, **extra_args)
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
|
||||
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
|
@@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
||||
class Flux(SD3):
|
||||
latent_channels = 16
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
@@ -162,6 +163,7 @@ class Flux(SD3):
|
||||
[-0.0005, -0.0530, -0.0020],
|
||||
[-0.1273, -0.0932, -0.0680]
|
||||
]
|
||||
self.taesd_decoder_name = "taef1_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
|
||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
|
||||
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
|
||||
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
def rms_norm(x, weight, eps=1e-6):
|
||||
if rms_norm_torch is not None:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
else:
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
151
comfy/ldm/flux/controlnet.py
Normal file
151
comfy/ldm/flux/controlnet.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
|
||||
import torch
|
||||
import math
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
|
||||
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
|
||||
self.main_model_double = 19
|
||||
self.main_model_single = 38
|
||||
# add ControlNet blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth):
|
||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth_single_blocks):
|
||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
||||
|
||||
self.num_union_modes = num_union_modes
|
||||
self.controlnet_mode_embedder = None
|
||||
if self.num_union_modes > 0:
|
||||
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.latent_input = latent_input
|
||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
if not self.latent_input:
|
||||
self.input_hint_block = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
controlnet_cond: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control_type: Tensor = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
if not self.latent_input:
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||
txt = torch.cat([control_cond, txt], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
controlnet_double = ()
|
||||
|
||||
for i in range(len(self.double_blocks)):
|
||||
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
controlnet_single = ()
|
||||
|
||||
for i in range(len(self.single_blocks)):
|
||||
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||
if self.latent_input:
|
||||
out_input = ()
|
||||
for x in controlnet_double:
|
||||
out_input += (x,) * repeat
|
||||
else:
|
||||
out_input = (controlnet_double * repeat)
|
||||
|
||||
out = {"input": out_input[:self.main_model_double]}
|
||||
if len(controlnet_single) > 0:
|
||||
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
||||
out_output = ()
|
||||
if self.latent_input:
|
||||
for x in controlnet_single:
|
||||
out_output += (x,) * repeat
|
||||
else:
|
||||
out_output = (controlnet_single * repeat)
|
||||
out["output"] = out_output[:self.main_model_single]
|
||||
return out
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
else:
|
||||
hint = hint * 2.0 - 1.0
|
||||
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
@@ -6,6 +6,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
@@ -170,15 +168,15 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img += img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = txt.clip(-65504, 65504)
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
@@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = x.clip(-65504, 65504)
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
|
@@ -38,7 +38,7 @@ class Flux(nn.Module):
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
@@ -83,7 +83,8 @@ class Flux(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@@ -94,6 +95,7 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -112,18 +114,34 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@@ -138,5 +156,5 @@ class Flux(nn.Module):
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
|
||||
for layer, block in enumerate(self.blocks):
|
||||
if layer > self.depth // 2:
|
||||
if controls is not None:
|
||||
skip = skips.pop() + controls.pop()
|
||||
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||
else:
|
||||
skip = skips.pop()
|
||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||
|
@@ -358,7 +358,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
|
@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
|
236
comfy/lora.py
236
comfy/lora.py
@@ -16,8 +16,12 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
import logging
|
||||
import torch
|
||||
|
||||
LORA_CLIP_MAP = {
|
||||
"mlp.fc1": "mlp_fc1",
|
||||
@@ -318,7 +322,235 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map[key_lora] = to
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||
"""
|
||||
Pad a tensor to a new shape with zeros.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The original tensor to be padded.
|
||||
new_shape (List[int]): The desired shape of the padded tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||
|
||||
Note:
|
||||
If the new shape is smaller than the original tensor in any dimension,
|
||||
the original tensor will be truncated in that dimension.
|
||||
"""
|
||||
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||
|
||||
if len(new_shape) != len(tensor.shape):
|
||||
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||
|
||||
# Create a new tensor filled with zeros
|
||||
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
# Create slicing tuples for both tensors
|
||||
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||
|
||||
# Copy the original tensor into the new tensor
|
||||
padded_tensor[new_slices] = tensor[orig_slices]
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
diff: torch.Tensor = v[0]
|
||||
# An extra flag to pad the weight if the diff's shape is larger than the weight
|
||||
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
|
||||
if do_pad_weight and diff.shape != weight.shape:
|
||||
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
|
||||
weight = pad_tensor_to_shape(weight, diff.shape)
|
||||
|
||||
if strength != 0.0:
|
||||
if diff.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@@ -77,10 +95,10 @@ class BaseModel(torch.nn.Module):
|
||||
self.device = device
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
if model_config.custom_operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
|
@@ -472,9 +472,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
@@ -44,9 +44,15 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
torch_version = ""
|
||||
try:
|
||||
torch_version = torch.version.__version__
|
||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
lowvram_available = True
|
||||
if args.deterministic:
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
@@ -66,10 +72,10 @@ if args.directml is not None:
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@@ -315,17 +320,15 @@ class LoadedModel:
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
||||
with torch.no_grad():
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
@@ -338,8 +341,9 @@ class LoadedModel:
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
return False
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
@@ -366,8 +370,21 @@ def offloaded_memory(loaded_models, device):
|
||||
offloaded_mem += m.model_offloaded_memory()
|
||||
return offloaded_mem
|
||||
|
||||
WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||
|
||||
def extra_reserved_memory():
|
||||
return EXTRA_RESERVED_VRAM
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 1.2
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
to_unload = []
|
||||
@@ -391,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||
@@ -434,15 +453,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024
|
||||
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models = set(models)
|
||||
|
||||
@@ -513,7 +532,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
lowvram_model_memory = 0
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
||||
@@ -552,7 +571,9 @@ def loaded_models(only_currently_used=False):
|
||||
def cleanup_models(keep_clone_weights_loaded=False):
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
#TODO: very fragile function needs improvement
|
||||
num_refs = sys.getrefcount(current_loaded_models[i].model)
|
||||
if num_refs <= 2:
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
#TODO: find a less fragile way to do this.
|
||||
@@ -659,6 +680,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
@@ -684,6 +706,20 @@ def text_encoder_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
if is_device_mps(load_device):
|
||||
return offload_device
|
||||
|
||||
mem_l = get_free_memory(load_device)
|
||||
mem_o = get_free_memory(offload_device)
|
||||
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||
return load_device
|
||||
else:
|
||||
return offload_device
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
@@ -860,7 +896,8 @@ def pytorch_attention_flash_attention():
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
try:
|
||||
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
|
||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
||||
upcast = True
|
||||
except:
|
||||
pass
|
||||
@@ -956,23 +993,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
if props.major < 6:
|
||||
return False
|
||||
|
||||
fp16_works = False
|
||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
#when the model doesn't actually fit on the card
|
||||
#TODO: actually test if GP106 and others have the same type of behavior
|
||||
#FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
if WINDOWS or manual_cast:
|
||||
return True
|
||||
else:
|
||||
return False #weird linux behavior where fp32 is faster
|
||||
|
||||
if fp16_works or manual_cast:
|
||||
if manual_cast:
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
@@ -1012,7 +1049,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
@@ -1025,6 +1062,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 9:
|
||||
return True
|
||||
if props.major < 8:
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
return True
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
@@ -22,32 +22,26 @@ import inspect
|
||||
import logging
|
||||
import uuid
|
||||
import collections
|
||||
import math
|
||||
|
||||
import comfy.utils
|
||||
import comfy.float
|
||||
import comfy.model_management
|
||||
import comfy.lora
|
||||
from comfy.types import UnetWrapperFunction
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
for byte in data:
|
||||
if isinstance(byte, str):
|
||||
byte = ord(byte)
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
if crc & 1:
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
@@ -90,12 +84,11 @@ def wipe_lowvram_weight(m):
|
||||
m.bias_function = None
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
def __init__(self, key, patches):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
self.patches = patches
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
@@ -327,46 +320,36 @@ class ModelPatcher:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||
|
||||
def patch_model(self, device_to=None, patch_weights=True):
|
||||
for k in self.object_patches:
|
||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
|
||||
if patch_weights:
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
||||
continue
|
||||
|
||||
self.patch_weight_to_device(key, device_to)
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = self.model_size()
|
||||
|
||||
return self.model
|
||||
|
||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||
loading.append((comfy.model_management.module_size(m), n, m))
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
module_mem = x[0]
|
||||
|
||||
lowvram_weight = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
if m.comfy_cast_weights:
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
@@ -377,13 +360,13 @@ class ModelPatcher:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
@@ -394,202 +377,56 @@ class ModelPatcher:
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
mem_counter += comfy.model_management.module_size(m)
|
||||
param = list(m.parameters())
|
||||
if len(param) > 0:
|
||||
weight = param[0]
|
||||
if weight.device == device_to:
|
||||
continue
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m))
|
||||
|
||||
self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
|
||||
self.patch_weight_to_device(bias_key)
|
||||
m.to(device_to)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
|
||||
self.patch_weight_to_device(weight_key, device_to=device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to=device_to)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
mem_counter = self.model_size()
|
||||
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
for k in self.object_patches:
|
||||
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
if lowvram_model_memory == 0:
|
||||
full_load = True
|
||||
else:
|
||||
full_load = False
|
||||
|
||||
if load_weights:
|
||||
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if strength != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
dora_scale = v[4]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
||||
else:
|
||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
if unpatch_weights:
|
||||
if self.model.model_lowvram:
|
||||
@@ -615,6 +452,10 @@ class ModelPatcher:
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
del m.comfy_patched_weights
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
@@ -624,40 +465,47 @@ class ModelPatcher:
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = []
|
||||
|
||||
for n, m in list(self.model.named_modules())[::-1]:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
|
||||
for n, m in self.model.named_modules():
|
||||
shift_lowvram = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
unload_list.append((module_mem, n, m))
|
||||
|
||||
unload_list.sort()
|
||||
for unload in unload_list:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
module_mem = unload[0]
|
||||
n = unload[1]
|
||||
m = unload[2]
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if m.weight is not None and m.weight.device != device_to:
|
||||
for key in [weight_key, bias_key]:
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
for key in [weight_key, bias_key]:
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
|
||||
m.to(device_to)
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
patch_counter += 1
|
||||
m.to(device_to)
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
@@ -665,13 +513,20 @@ class ModelPatcher:
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0):
|
||||
self.unpatch_model(unpatch_weights=False)
|
||||
self.patch_model(load_weights=False)
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False:
|
||||
return 0
|
||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||
pass #TODO: Full load
|
||||
full_load = True
|
||||
current_used = self.model.model_loaded_weight_memory
|
||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
|
||||
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def current_loaded_device(self):
|
||||
return self.model.device
|
||||
|
||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
||||
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||
|
89
comfy/ops.py
89
comfy/ops.py
@@ -18,29 +18,42 @@
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
||||
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None):
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
has_function = s.bias_function is not None
|
||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
bias = s.bias_function(bias)
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
|
||||
has_function = s.weight_function is not None
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
|
||||
@@ -238,3 +251,59 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
if len(input.shape) == 3:
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||
w = w.t()
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
if scale_input is None:
|
||||
scale_input = scale_weight
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||
return None
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def reset_parameters(self):
|
||||
self.scale_weight = None
|
||||
self.scale_input = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
return disable_weight_init
|
||||
if args.fast:
|
||||
if comfy.model_management.supports_fp8_compute(load_device):
|
||||
return fp8_ops
|
||||
return manual_cast
|
||||
|
84
comfy/sd.py
84
comfy/sd.py
@@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
|
||||
import comfy.text_encoders.aura_t5
|
||||
import comfy.text_encoders.hydit
|
||||
import comfy.text_encoders.flux
|
||||
import comfy.text_encoders.long_clipl
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -62,7 +63,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@@ -71,20 +72,29 @@ class CLIP:
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
dtype = model_options.get("dtype", None)
|
||||
if dtype is None:
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
params['dtype'] = dtype
|
||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||
params['model_options'] = model_options
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
for dt in self.cond_stage_model.dtypes:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
if params['device'] != offload_device:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@@ -390,11 +400,14 @@ class CLIPType(Enum):
|
||||
HUNYUAN_DIT = 5
|
||||
FLUX = 6
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = state_dicts
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
@@ -431,8 +444,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
@@ -456,7 +474,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@@ -498,15 +520,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
clip_target = None
|
||||
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
@@ -515,13 +541,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return None
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("weight_dtype", None)
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
@@ -545,7 +576,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
@@ -567,12 +599,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
|
||||
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
@@ -614,6 +647,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
@@ -622,24 +656,36 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
|
||||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd, dtype=dtype)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None):
|
||||
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
vae_sd = None
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
|
@@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
|
||||
return r
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
@@ -84,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
@@ -94,7 +93,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.operations = comfy.ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
@@ -552,8 +555,12 @@ class SD1Tokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1CheckpointClipModel(SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if name is not None:
|
||||
@@ -563,7 +570,7 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
@@ -3,14 +3,14 @@ import torch
|
||||
import os
|
||||
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
@@ -38,10 +38,10 @@ class SDXLTokenizer:
|
||||
return {}
|
||||
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set([dtype])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@@ -66,8 +66,8 @@ class SDXLClipModel(torch.nn.Module):
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
@@ -79,14 +79,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)
|
||||
|
@@ -181,7 +181,7 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
memory_usage_factor = 0.7
|
||||
memory_usage_factor = 0.8
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
@@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
dtype_t5 = None
|
||||
if t5_key in state_dict:
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
|
||||
|
@@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from . import model_base
|
||||
from . import utils
|
||||
@@ -30,6 +48,7 @@ class BASE:
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
|
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class PT5XlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
|
||||
|
||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
|
||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
|
@@ -6,9 +6,9 @@ import torch
|
||||
import os
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -35,11 +35,11 @@ class FluxTokenizer:
|
||||
|
||||
|
||||
class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@@ -66,6 +66,6 @@ class FluxClipModel(torch.nn.Module):
|
||||
|
||||
def flux_clip(dtype_t5=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
@@ -7,9 +7,9 @@ import os
|
||||
import torch
|
||||
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class MT5XLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -50,10 +50,10 @@ class HyditTokenizer:
|
||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||
|
||||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
25
comfy/text_encoders/long_clipl.json
Normal file
25
comfy/text_encoders/long_clipl.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 49407,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 248,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.24.0",
|
||||
"vocab_size": 49408
|
||||
}
|
19
comfy/text_encoders/long_clipl.py
Normal file
19
comfy/text_encoders/long_clipl.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
|
||||
import os
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
|
@@ -2,13 +2,13 @@ from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||
|
@@ -8,14 +8,14 @@ import comfy.model_management
|
||||
import logging
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
@@ -38,24 +38,24 @@ class SD3Tokenizer:
|
||||
return {}
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
||||
if clip_g:
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_g = None
|
||||
|
||||
if t5:
|
||||
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes.add(dtype_t5)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
@@ -132,6 +132,6 @@ class SD3ClipModel(torch.nn.Module):
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
|
||||
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
|
308
comfy_execution/caching.py
Normal file
308
comfy_execution/caching.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
|
||||
import nodes
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
|
||||
class CacheKeySet:
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
return set(self.keys.keys())
|
||||
|
||||
def get_used_keys(self):
|
||||
return self.keys.values()
|
||||
|
||||
def get_used_subcache_keys(self):
|
||||
return self.subcache_keys.values()
|
||||
|
||||
def get_data_key(self, node_id):
|
||||
return self.keys.get(node_id, None)
|
||||
|
||||
def get_subcache_key(self, node_id):
|
||||
return self.subcache_keys.get(node_id, None)
|
||||
|
||||
class Unhashable:
|
||||
def __init__(self):
|
||||
self.value = float("NaN")
|
||||
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
elif isinstance(obj, Sequence):
|
||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||
else:
|
||||
# TODO - Support other objects like tensors?
|
||||
return Unhashable()
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = (node_id, node["class_type"])
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
class CacheKeySetInputSignature(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
(ancestor_id, ancestor_socket) = inputs[key]
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
return signature
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
# deterministic based on which specific inputs the ancestor is connected by.
|
||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||
ancestors = []
|
||||
order_mapping = {}
|
||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||
return ancestors, order_mapping
|
||||
|
||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
return
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
input_keys = sorted(inputs.keys())
|
||||
for key in input_keys:
|
||||
if is_link(inputs[key]):
|
||||
ancestor_id = inputs[key][0]
|
||||
if ancestor_id not in order_mapping:
|
||||
ancestors.append(ancestor_id)
|
||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
self.key_class = key_class
|
||||
self.initialized = False
|
||||
self.dynprompt: DynamicPrompt
|
||||
self.cache_key_set: CacheKeySet
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.initialized = True
|
||||
|
||||
def all_node_ids(self):
|
||||
assert self.initialized
|
||||
node_ids = self.cache_key_set.all_node_ids()
|
||||
for subcache in self.subcaches.values():
|
||||
node_ids = node_ids.union(subcache.all_node_ids())
|
||||
return node_ids
|
||||
|
||||
def _clean_cache(self):
|
||||
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||
to_remove = []
|
||||
for key in self.cache:
|
||||
if key not in preserve_keys:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
|
||||
def _clean_subcaches(self):
|
||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||
|
||||
to_remove = []
|
||||
for key in self.subcaches:
|
||||
if key not in preserve_subcaches:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del self.subcaches[key]
|
||||
|
||||
def clean_unused(self):
|
||||
assert self.initialized
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
assert self.initialized
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
return self.subcaches[subcache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = []
|
||||
for key in self.cache:
|
||||
result.append({"key": key, "value": self.cache[key]})
|
||||
for key in self.subcaches:
|
||||
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
|
||||
return result
|
||||
|
||||
class HierarchicalCache(BasicCache):
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class)
|
||||
|
||||
def _get_cache_for(self, node_id):
|
||||
assert self.dynprompt is not None
|
||||
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
||||
if parent_id is None:
|
||||
return self
|
||||
|
||||
hierarchy = []
|
||||
while parent_id is not None:
|
||||
hierarchy.append(parent_id)
|
||||
parent_id = self.dynprompt.get_parent_node_id(parent_id)
|
||||
|
||||
cache = self
|
||||
for parent_id in reversed(hierarchy):
|
||||
cache = cache._get_subcache(parent_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache
|
||||
|
||||
def get(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache._get_immediate(node_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
self.max_size = max_size
|
||||
self.min_generation = 0
|
||||
self.generation = 0
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
|
||||
def clean_unused(self):
|
||||
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||
self.min_generation += 1
|
||||
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
del self.used_generation[key]
|
||||
if key in self.children:
|
||||
del self.children[key]
|
||||
self._clean_subcaches()
|
||||
|
||||
def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def _mark_used(self, node_id):
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key is not None:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
|
||||
def set(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
super()._ensure_subcache(node_id, children_ids)
|
||||
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
for child_id in children_ids:
|
||||
self._mark_used(child_id)
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
259
comfy_execution/graph.py
Normal file
259
comfy_execution/graph.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import nodes
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
class NodeInputError(Exception):
|
||||
pass
|
||||
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
self.original_prompt = original_prompt
|
||||
# Any extra pieces of the graph created during execution
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
return self.ephemeral_prompt[node_id]
|
||||
if node_id in self.original_prompt:
|
||||
return self.original_prompt[node_id]
|
||||
raise NodeNotFoundError(f"Node {node_id} not found")
|
||||
|
||||
def has_node(self, node_id):
|
||||
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
|
||||
|
||||
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
node_id = self.ephemeral_parents[node_id]
|
||||
return node_id
|
||||
|
||||
def get_parent_node_id(self, node_id):
|
||||
return self.ephemeral_parents.get(node_id, None)
|
||||
|
||||
def get_display_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_display:
|
||||
node_id = self.ephemeral_display[node_id]
|
||||
return node_id
|
||||
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
def get_input_info(class_def, input_name):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_info = None
|
||||
input_category = None
|
||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||
input_category = "required"
|
||||
input_info = valid_inputs["required"][input_name]
|
||||
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
||||
input_category = "optional"
|
||||
input_info = valid_inputs["optional"][input_name]
|
||||
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
||||
input_category = "hidden"
|
||||
input_info = valid_inputs["hidden"][input_name]
|
||||
if input_info is None:
|
||||
return None, None, None
|
||||
input_type = input_info[0]
|
||||
if len(input_info) > 1:
|
||||
extra_info = input_info[1]
|
||||
else:
|
||||
extra_info = {}
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
class TopologicalSort:
|
||||
def __init__(self, dynprompt):
|
||||
self.dynprompt = dynprompt
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return get_input_info(class_def, input_name)
|
||||
|
||||
def make_input_strong_link(self, to_node_id, to_input):
|
||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||
if to_input not in inputs:
|
||||
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
|
||||
value = inputs[to_input]
|
||||
if not is_link(value):
|
||||
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
|
||||
from_node_id, from_socket = value
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if include_lazy or not is_lazy:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
|
||||
def get_ready_nodes(self):
|
||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
|
||||
def pop_node(self, unique_id):
|
||||
del self.pendingNodes[unique_id]
|
||||
for blocked_node_id in self.blocking[unique_id]:
|
||||
self.blockCount[blocked_node_id] -= 1
|
||||
del self.blocking[unique_id]
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
class ExecutionList(TopologicalSort):
|
||||
"""
|
||||
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
it can still be returned to the graph after having further dependencies added.
|
||||
"""
|
||||
def __init__(self, dynprompt, output_cache):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if self.output_cache.get(from_node_id) is not None:
|
||||
# Nothing to do
|
||||
return
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None, None, None
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
cycled_nodes = self.get_nodes_in_cycle()
|
||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||
# we will 'blame' the first node in the cycle that is not a static node.
|
||||
blamed_node = cycled_nodes[0]
|
||||
for node_id in cycled_nodes:
|
||||
display_node_id = self.dynprompt.get_display_node_id(node_id)
|
||||
if display_node_id != node_id:
|
||||
blamed_node = display_node_id
|
||||
break
|
||||
ex = DependencyCycleError("Dependency cycle detected")
|
||||
error_details = {
|
||||
"node_id": blamed_node,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": "graph.DependencyCycleError",
|
||||
"traceback": [],
|
||||
"current_inputs": []
|
||||
}
|
||||
return None, error_details, ex
|
||||
|
||||
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||
return self.staged_node_id, None, None
|
||||
|
||||
def ux_friendly_pick_node(self, node_list):
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
def is_output(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
return True
|
||||
return False
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
if is_output(blocked_node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||
if is_output(blocked_node_id1):
|
||||
return node_id
|
||||
|
||||
#TODO: this function should be improved
|
||||
return node_list[0]
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.staged_node_id = None
|
||||
|
||||
def get_nodes_in_cycle(self):
|
||||
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
|
||||
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
|
||||
# the code simple (and because having a cycle in the first place is a catastrophic error)
|
||||
blocked_by = { node_id: {} for node_id in self.pendingNodes }
|
||||
for from_node_id in self.blocking:
|
||||
for to_node_id in self.blocking[from_node_id]:
|
||||
if True in self.blocking[from_node_id][to_node_id].values():
|
||||
blocked_by[to_node_id][from_node_id] = True
|
||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||
while len(to_remove) > 0:
|
||||
for node_id in to_remove:
|
||||
for to_node_id in blocked_by:
|
||||
if node_id in blocked_by[to_node_id]:
|
||||
del blocked_by[to_node_id][node_id]
|
||||
del blocked_by[node_id]
|
||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||
return list(blocked_by.keys())
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
139
comfy_execution/graph_utils.py
Normal file
139
comfy_execution/graph_utils.py
Normal file
@@ -0,0 +1,139 @@
|
||||
def is_link(obj):
|
||||
if not isinstance(obj, list):
|
||||
return False
|
||||
if len(obj) != 2:
|
||||
return False
|
||||
if not isinstance(obj[0], str):
|
||||
return False
|
||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
||||
return False
|
||||
return True
|
||||
|
||||
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
|
||||
class GraphBuilder:
|
||||
_default_prefix_root = ""
|
||||
_default_prefix_call_index = 0
|
||||
_default_prefix_graph_index = 0
|
||||
|
||||
def __init__(self, prefix = None):
|
||||
if prefix is None:
|
||||
self.prefix = GraphBuilder.alloc_prefix()
|
||||
else:
|
||||
self.prefix = prefix
|
||||
self.nodes = {}
|
||||
self.id_gen = 1
|
||||
|
||||
@classmethod
|
||||
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
|
||||
cls._default_prefix_root = prefix_root
|
||||
cls._default_prefix_call_index = call_index
|
||||
cls._default_prefix_graph_index = graph_index
|
||||
|
||||
@classmethod
|
||||
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
|
||||
if root is None:
|
||||
root = GraphBuilder._default_prefix_root
|
||||
if call_index is None:
|
||||
call_index = GraphBuilder._default_prefix_call_index
|
||||
if graph_index is None:
|
||||
graph_index = GraphBuilder._default_prefix_graph_index
|
||||
result = f"{root}.{call_index}.{graph_index}."
|
||||
GraphBuilder._default_prefix_graph_index += 1
|
||||
return result
|
||||
|
||||
def node(self, class_type, id=None, **kwargs):
|
||||
if id is None:
|
||||
id = str(self.id_gen)
|
||||
self.id_gen += 1
|
||||
id = self.prefix + id
|
||||
if id in self.nodes:
|
||||
return self.nodes[id]
|
||||
|
||||
node = Node(id, class_type, kwargs)
|
||||
self.nodes[id] = node
|
||||
return node
|
||||
|
||||
def lookup_node(self, id):
|
||||
id = self.prefix + id
|
||||
return self.nodes.get(id)
|
||||
|
||||
def finalize(self):
|
||||
output = {}
|
||||
for node_id, node in self.nodes.items():
|
||||
output[node_id] = node.serialize()
|
||||
return output
|
||||
|
||||
def replace_node_output(self, node_id, index, new_value):
|
||||
node_id = self.prefix + node_id
|
||||
to_remove = []
|
||||
for node in self.nodes.values():
|
||||
for key, value in node.inputs.items():
|
||||
if is_link(value) and value[0] == node_id and value[1] == index:
|
||||
if new_value is None:
|
||||
to_remove.append((node, key))
|
||||
else:
|
||||
node.inputs[key] = new_value
|
||||
for node, key in to_remove:
|
||||
del node.inputs[key]
|
||||
|
||||
def remove_node(self, id):
|
||||
id = self.prefix + id
|
||||
del self.nodes[id]
|
||||
|
||||
class Node:
|
||||
def __init__(self, id, class_type, inputs):
|
||||
self.id = id
|
||||
self.class_type = class_type
|
||||
self.inputs = inputs
|
||||
self.override_display_id = None
|
||||
|
||||
def out(self, index):
|
||||
return [self.id, index]
|
||||
|
||||
def set_input(self, key, value):
|
||||
if value is None:
|
||||
if key in self.inputs:
|
||||
del self.inputs[key]
|
||||
else:
|
||||
self.inputs[key] = value
|
||||
|
||||
def get_input(self, key):
|
||||
return self.inputs.get(key)
|
||||
|
||||
def set_override_display_id(self, override_display_id):
|
||||
self.override_display_id = override_display_id
|
||||
|
||||
def serialize(self):
|
||||
serialized = {
|
||||
"class_type": self.class_type,
|
||||
"inputs": self.inputs
|
||||
}
|
||||
if self.override_display_id is not None:
|
||||
serialized["override_display_id"] = self.override_display_id
|
||||
return serialized
|
||||
|
||||
def add_graph_prefix(graph, outputs, prefix):
|
||||
# Change the node IDs and any internal links
|
||||
new_graph = {}
|
||||
for node_id, node_info in graph.items():
|
||||
# Make sure the added nodes have unique IDs
|
||||
new_node_id = prefix + node_id
|
||||
new_node = { "class_type": node_info["class_type"], "inputs": {} }
|
||||
for input_name, input_value in node_info.get("inputs", {}).items():
|
||||
if is_link(input_value):
|
||||
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
|
||||
else:
|
||||
new_node["inputs"][input_name] = input_value
|
||||
new_graph[new_node_id] = new_node
|
||||
|
||||
# Change the node IDs in the outputs
|
||||
new_outputs = []
|
||||
for n in range(len(outputs)):
|
||||
output = outputs[n]
|
||||
if is_link(output):
|
||||
new_outputs.append([prefix + output[0], output[1]])
|
||||
else:
|
||||
new_outputs.append(output)
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
@@ -333,6 +333,25 @@ class VAESave:
|
||||
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||
return {}
|
||||
|
||||
class ModelSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||
return {}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSimple": ModelMergeSimple,
|
||||
"ModelMergeBlocks": ModelMergeBlocks,
|
||||
@@ -344,4 +363,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CLIPMergeAdd": CLIPAdd,
|
||||
"CLIPSave": CLIPSave,
|
||||
"VAESave": VAESave,
|
||||
"ModelSave": ModelSave,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CheckpointSave": "Save Checkpoint",
|
||||
}
|
||||
|
@@ -4,14 +4,14 @@ class Example:
|
||||
|
||||
Class methods
|
||||
-------------
|
||||
INPUT_TYPES (dict):
|
||||
INPUT_TYPES (dict):
|
||||
Tell the main program input parameters of nodes.
|
||||
IS_CHANGED:
|
||||
optional method to control when the node is re executed.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
RETURN_TYPES (`tuple`):
|
||||
RETURN_TYPES (`tuple`):
|
||||
The type of each element in the output tuple.
|
||||
RETURN_NAMES (`tuple`):
|
||||
Optional: The name of each output in the output tuple.
|
||||
@@ -23,13 +23,19 @@ class Example:
|
||||
Assumed to be False if not present.
|
||||
CATEGORY (`str`):
|
||||
The category the node should appear in the UI.
|
||||
DEPRECATED (`bool`):
|
||||
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
|
||||
functional in existing workflows that use them.
|
||||
EXPERIMENTAL (`bool`):
|
||||
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
|
||||
significant changes or removal in future versions. Use with caution in production workflows.
|
||||
execute(s) -> tuple || None:
|
||||
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
|
||||
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
"""
|
||||
@@ -54,7 +60,8 @@ class Example:
|
||||
"min": 0, #Minimum value
|
||||
"max": 4096, #Maximum value
|
||||
"step": 64, #Slider's step
|
||||
"display": "number" # Cosmetic only: display as "number" or "slider"
|
||||
"display": "number", # Cosmetic only: display as "number" or "slider"
|
||||
"lazy": True # Will only be evaluated if check_lazy_status requires it
|
||||
}),
|
||||
"float_field": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
@@ -62,11 +69,14 @@ class Example:
|
||||
"max": 10.0,
|
||||
"step": 0.01,
|
||||
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
|
||||
"display": "number"}),
|
||||
"display": "number",
|
||||
"lazy": True
|
||||
}),
|
||||
"print_to_screen": (["enable", "disable"],),
|
||||
"string_field": ("STRING", {
|
||||
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||
"default": "Hello World!"
|
||||
"default": "Hello World!",
|
||||
"lazy": True
|
||||
}),
|
||||
},
|
||||
}
|
||||
@@ -80,6 +90,23 @@ class Example:
|
||||
|
||||
CATEGORY = "Example"
|
||||
|
||||
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
|
||||
"""
|
||||
Return a list of input names that need to be evaluated.
|
||||
|
||||
This function will be called if there are any lazy inputs which have not yet been
|
||||
evaluated. As long as you return at least one field which has not yet been evaluated
|
||||
(and more exist), this function will be called again once the value of the requested
|
||||
field is available.
|
||||
|
||||
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
|
||||
inputs will have the value None.
|
||||
"""
|
||||
if print_to_screen == "enable":
|
||||
return ["int_field", "float_field", "string_field"]
|
||||
else:
|
||||
return []
|
||||
|
||||
def test(self, image, string_field, int_field, float_field, print_to_screen):
|
||||
if print_to_screen == "enable":
|
||||
print(f"""Your input contains:
|
||||
|
648
execution.py
648
execution.py
@@ -5,6 +5,7 @@ import threading
|
||||
import heapq
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
|
||||
@@ -12,102 +13,219 @@ import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy.cli_args import args
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
FAILURE = 1
|
||||
PENDING = 2
|
||||
|
||||
class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get(self, node_id):
|
||||
if node_id in self.is_changed:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if not hasattr(class_def, "IS_CHANGED"):
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
|
||||
if "is_changed" in node:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
logging.warning("WARNING: {}".format(e))
|
||||
node["is_changed"] = float("NaN")
|
||||
finally:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
return self.is_changed[node_id]
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
"ui": self.ui.recursive_debug_dump(),
|
||||
}
|
||||
return result
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
missing_keys = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||
def mark_missing():
|
||||
missing_keys[x] = True
|
||||
input_data_all[x] = (None,)
|
||||
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
input_data_all[x] = (None,)
|
||||
if outputs is None:
|
||||
mark_missing()
|
||||
continue # This might be a lazily-evaluated input
|
||||
cached_output = outputs.get(input_unique_id)
|
||||
if cached_output is None:
|
||||
mark_missing()
|
||||
continue
|
||||
obj = outputs[input_unique_id][output_index]
|
||||
if output_index >= len(cached_output):
|
||||
mark_missing()
|
||||
continue
|
||||
obj = cached_output[output_index]
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = [input_data]
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [prompt]
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
return input_data_all, missing_keys
|
||||
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
input_is_list = obj.INPUT_IS_LIST
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
if len(input_data_all) == 0:
|
||||
max_len_input = 0
|
||||
else:
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
max_len_input = max(len(x) for x in input_data_all.values())
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
results = []
|
||||
def process_inputs(inputs, index=None):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
execution_block = None
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, ExecutionBlocker):
|
||||
execution_block = execution_block_cb(v) if execution_block_cb else v
|
||||
break
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
results.append(getattr(obj, func)(**inputs))
|
||||
else:
|
||||
results.append(execution_block)
|
||||
|
||||
if input_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
process_inputs(input_data_all, 0)
|
||||
elif max_len_input == 0:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)())
|
||||
else:
|
||||
process_inputs({})
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
process_inputs(input_dict, i)
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
subgraph_results = []
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
results.append(r['result'])
|
||||
if 'expand' in r:
|
||||
# Perform an expansion, but do not append results
|
||||
has_subgraph = True
|
||||
new_graph = r['expand']
|
||||
result = r.get("result", None)
|
||||
if isinstance(result, ExecutionBlocker):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
subgraph_results.append((new_graph, result))
|
||||
elif 'result' in r:
|
||||
result = r.get("result", None)
|
||||
if isinstance(result, ExecutionBlocker):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
else:
|
||||
if isinstance(r, ExecutionBlocker):
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
results.append(r)
|
||||
subgraph_results.append((None, r))
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
if has_subgraph:
|
||||
output = subgraph_results
|
||||
elif len(results) > 0:
|
||||
output = merge_result_data(results, obj)
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
return output, ui, has_subgraph
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
@@ -117,53 +235,145 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
parent_node_id = dynprompt.get_parent_node_id(unique_id)
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return (True, None, None)
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
if caches.outputs.get(unique_id) is not None:
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
if unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
if not is_subgraph:
|
||||
resolved_outputs.append(result)
|
||||
else:
|
||||
resolved_output = []
|
||||
for r in result:
|
||||
if is_link(r):
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = caches.outputs.get(source_node)[source_output]
|
||||
for o in node_output:
|
||||
resolved_output.append(o)
|
||||
|
||||
obj = object_storage.get((unique_id, class_type), None)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
object_storage[(unique_id, class_type)] = obj
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
else:
|
||||
resolved_output.append(r)
|
||||
resolved_outputs.append(tuple(resolved_output))
|
||||
output_data = merge_result_data(resolved_outputs, class_def)
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
obj = caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
)]
|
||||
if len(required_inputs) > 0:
|
||||
for i in required_inputs:
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": unique_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": f"Execution Blocked: {block.message}",
|
||||
"exception_type": "ExecutionBlocked",
|
||||
"traceback": [],
|
||||
"current_inputs": [],
|
||||
"current_outputs": [],
|
||||
}
|
||||
server.send_sync("execution_error", mes, server.client_id)
|
||||
return ExecutionBlocker(None)
|
||||
else:
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
"display_node": display_node_id,
|
||||
"parent_node": parent_node_id,
|
||||
"real_node_id": real_node_id,
|
||||
},
|
||||
"output": output_ui
|
||||
})
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
cached_outputs = []
|
||||
new_node_ids = []
|
||||
new_output_ids = []
|
||||
new_output_links = []
|
||||
for i in range(len(output_data)):
|
||||
new_graph, node_outputs = output_data[i]
|
||||
if new_graph is None:
|
||||
cached_outputs.append((False, node_outputs))
|
||||
else:
|
||||
# Check for conflicts
|
||||
for node_id in new_graph.keys():
|
||||
if dynprompt.has_node(node_id):
|
||||
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
|
||||
for node_id, node_info in new_graph.items():
|
||||
new_node_ids.append(node_id)
|
||||
display_id = node_info.get("override_display_id", unique_id)
|
||||
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
|
||||
# Figure out if the newly created node is an output node
|
||||
class_type = node_info["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
new_output_ids.append(node_id)
|
||||
for i in range(len(node_outputs)):
|
||||
if is_link(node_outputs[i]):
|
||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||
new_output_links.append((from_node_id, from_socket))
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
caches.outputs.set(unique_id, output_data)
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"node_id": real_node_id,
|
||||
}
|
||||
|
||||
return (False, error_details, iex)
|
||||
return (ExecutionResult.FAILURE, error_details, iex)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
@@ -173,121 +383,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
for name, inputs in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
|
||||
output_data_formatted = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
logging.error(f"!!! Exception during processing!!! {ex}")
|
||||
logging.error(f"!!! Exception during processing !!! {ex}")
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"node_id": real_node_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
"current_inputs": input_data_formatted
|
||||
}
|
||||
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
return (False, error_details, ex)
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
return (True, None, None)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item, memo={}):
|
||||
unique_id = current_item
|
||||
|
||||
if unique_id in memo:
|
||||
return memo[unique_id]
|
||||
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
will_execute = []
|
||||
if unique_id in outputs:
|
||||
return []
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
|
||||
|
||||
memo[unique_id] = will_execute + [unique_id]
|
||||
return memo[unique_id]
|
||||
|
||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
||||
is_changed_old = ''
|
||||
is_changed = ''
|
||||
to_delete = False
|
||||
if hasattr(class_def, 'IS_CHANGED'):
|
||||
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
|
||||
is_changed_old = old_prompt[unique_id]['is_changed']
|
||||
if 'is_changed' not in prompt[unique_id]:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
to_delete = True
|
||||
else:
|
||||
is_changed = prompt[unique_id]['is_changed']
|
||||
|
||||
if unique_id not in outputs:
|
||||
return True
|
||||
|
||||
if not to_delete:
|
||||
if is_changed != is_changed_old:
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif class_type != old_prompt[unique_id]['class_type']:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id in outputs:
|
||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
||||
else:
|
||||
to_delete = True
|
||||
if to_delete:
|
||||
break
|
||||
else:
|
||||
to_delete = True
|
||||
|
||||
if to_delete:
|
||||
d = outputs.pop(unique_id)
|
||||
del d
|
||||
return to_delete
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
def __init__(self, server, lru_size=None):
|
||||
self.lru_size = lru_size
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.outputs = {}
|
||||
self.object_storage = {}
|
||||
self.outputs_ui = {}
|
||||
self.caches = CacheSet(self.lru_size)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
self.old_prompt = {}
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
@@ -318,26 +443,13 @@ class PromptExecutor:
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
"current_inputs": error["current_inputs"],
|
||||
"current_outputs": error["current_outputs"],
|
||||
"current_outputs": list(current_outputs),
|
||||
}
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
@@ -351,65 +463,59 @@ class PromptExecutor:
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode():
|
||||
#delete cached outputs if nodes don't exist for them
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if o not in prompt:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
to_delete = []
|
||||
for o in self.object_storage:
|
||||
if o[0] not in prompt:
|
||||
to_delete += [o]
|
||||
else:
|
||||
p = prompt[o[0]]
|
||||
if o[1] != p['class_type']:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.object_storage.pop(o)
|
||||
del d
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
for x in prompt:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
for x in list(self.outputs_ui.keys()):
|
||||
if x not in current_outputs:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
cached_nodes = []
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
|
||||
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
executed = set()
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
to_execute += [(0, node_id)]
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
memo = {}
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
|
||||
output_node_id = to_execute.pop(0)[-1]
|
||||
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
||||
if self.success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
all_node_ids = self.caches.ui.all_node_ids()
|
||||
for node_id in all_node_ids:
|
||||
ui_info = self.caches.ui.get(node_id)
|
||||
if ui_info is not None:
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
@@ -426,31 +532,37 @@ def validate_inputs(prompt, item, validated):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
required_inputs = class_inputs['required']
|
||||
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
|
||||
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
|
||||
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
||||
validate_function_inputs = argspec.args
|
||||
validate_has_kwargs = argspec.varkw is not None
|
||||
received_types = {}
|
||||
|
||||
for x in required_inputs:
|
||||
for x in valid_inputs:
|
||||
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
||||
assert extra_info is not None
|
||||
if x not in inputs:
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
if input_category == "required":
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
info = (type_input, extra_info)
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
error = {
|
||||
@@ -469,8 +581,9 @@ def validate_inputs(prompt, item, validated):
|
||||
o_id = val[0]
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
received_type = r[val[1]]
|
||||
received_type = r[val[1]]
|
||||
received_types[x] = received_type
|
||||
if 'input_types' not in validate_function_inputs and received_type != type_input:
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
@@ -521,6 +634,9 @@ def validate_inputs(prompt, item, validated):
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
if type_input == "BOOLEAN":
|
||||
val = bool(val)
|
||||
inputs[x] = val
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
@@ -536,11 +652,11 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
if x not in validate_function_inputs and not validate_has_kwargs:
|
||||
if "min" in extra_info and val < extra_info["min"]:
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
@@ -550,10 +666,10 @@ def validate_inputs(prompt, item, validated):
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
if "max" in extra_info and val > extra_info["max"]:
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
@@ -564,7 +680,6 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if x not in validate_function_inputs:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
input_config = info
|
||||
@@ -591,18 +706,20 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0:
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
input_filtered[x] = input_data_all[x]
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
@@ -613,8 +730,6 @@ def validate_inputs(prompt, item, validated):
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
@@ -780,7 +895,7 @@ class PromptQueue:
|
||||
completed: bool
|
||||
messages: List[str]
|
||||
|
||||
def task_done(self, item_id, outputs,
|
||||
def task_done(self, item_id, history_result,
|
||||
status: Optional['PromptQueue.ExecutionStatus']):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
@@ -793,9 +908,10 @@ class PromptQueue:
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
"outputs": copy.deepcopy(outputs),
|
||||
"outputs": {},
|
||||
'status': status_dict,
|
||||
}
|
||||
self.history[prompt[1]].update(history_result)
|
||||
self.server.queue_updated()
|
||||
|
||||
def get_current_queue(self):
|
||||
|
@@ -17,7 +17,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
|
||||
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||
@@ -44,6 +44,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
|
||||
|
||||
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||
|
||||
def map_legacy(folder_name: str) -> str:
|
||||
legacy = {"unet": "diffusion_models"}
|
||||
return legacy.get(folder_name, folder_name)
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
try:
|
||||
os.makedirs(input_directory)
|
||||
@@ -128,12 +132,14 @@ def exists_annotated_filepath(name) -> bool:
|
||||
|
||||
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name in folder_names_and_paths:
|
||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||
else:
|
||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||
|
||||
def get_folder_paths(folder_name: str) -> list[str]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
return folder_names_and_paths[folder_name][0][:]
|
||||
|
||||
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
|
||||
@@ -180,6 +186,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
||||
|
||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in folder_names_and_paths:
|
||||
return None
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
@@ -194,6 +201,7 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
return None
|
||||
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
global folder_names_and_paths
|
||||
output_list = set()
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
@@ -208,6 +216,7 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
|
||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
@@ -227,6 +236,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
|
||||
return out
|
||||
|
||||
def get_filename_list(folder_name: str) -> list[str]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
out = cached_filename_list_(folder_name)
|
||||
if out is None:
|
||||
out = get_filename_list_(folder_name)
|
||||
|
10
main.py
10
main.py
@@ -6,6 +6,10 @@ import importlib.util
|
||||
import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args
|
||||
from app.logger import setup_logger
|
||||
|
||||
|
||||
setup_logger(verbose=args.verbose)
|
||||
|
||||
|
||||
def execute_prestartup_script():
|
||||
@@ -101,7 +105,7 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
@@ -121,7 +125,7 @@ def prompt_worker(q, server):
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id,
|
||||
e.outputs_ui,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
@@ -242,6 +246,7 @@ if __name__ == "__main__":
|
||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||
|
||||
if args.input_directory:
|
||||
input_dir = os.path.abspath(args.input_directory)
|
||||
@@ -261,6 +266,7 @@ if __name__ == "__main__":
|
||||
call_on_start = startup_server
|
||||
|
||||
try:
|
||||
loop.run_until_complete(server.setup())
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
except KeyboardInterrupt:
|
||||
logging.info("\nStopped server")
|
||||
|
2
model_filemanager/__init__.py
Normal file
2
model_filemanager/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# model_manager/__init__.py
|
||||
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
|
240
model_filemanager/download_models.py
Normal file
240
model_filemanager/download_models.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import os
|
||||
import traceback
|
||||
import logging
|
||||
from folder_paths import models_dir
|
||||
import re
|
||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||
from enum import Enum
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class DownloadStatusType(Enum):
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
@dataclass
|
||||
class DownloadModelStatus():
|
||||
status: str
|
||||
progress_percentage: float
|
||||
message: str
|
||||
already_existed: bool = False
|
||||
|
||||
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
|
||||
self.status = status.value # Store the string value of the Enum
|
||||
self.progress_percentage = progress_percentage
|
||||
self.message = message
|
||||
self.already_existed = already_existed
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"status": self.status,
|
||||
"progress_percentage": self.progress_percentage,
|
||||
"message": self.message,
|
||||
"already_existed": self.already_existed
|
||||
}
|
||||
|
||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_sub_directory: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||
"""
|
||||
Download a model file from a given URL into the models directory.
|
||||
|
||||
Args:
|
||||
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
|
||||
A function that makes an HTTP request. This makes it easier to mock in unit tests.
|
||||
model_name (str):
|
||||
The name of the model file to be downloaded. This will be the filename on disk.
|
||||
model_url (str):
|
||||
The URL from which to download the model.
|
||||
model_sub_directory (str):
|
||||
The subdirectory within the main models directory where the model
|
||||
should be saved (e.g., 'checkpoints', 'loras', etc.).
|
||||
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
|
||||
An asynchronous function to call with progress updates.
|
||||
|
||||
Returns:
|
||||
DownloadModelStatus: The result of the download operation.
|
||||
"""
|
||||
if not validate_model_subdirectory(model_sub_directory):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model subdirectory",
|
||||
False
|
||||
)
|
||||
|
||||
if not validate_filename(model_name):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model name",
|
||||
False
|
||||
)
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
|
||||
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
|
||||
if existing_file:
|
||||
return existing_file
|
||||
|
||||
try:
|
||||
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
response = await model_download_request(model_url)
|
||||
if response.status != 200:
|
||||
error_message = f"Failed to download {model_name}. Status code: {response.status}"
|
||||
logging.error(error_message)
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(relative_path, status)
|
||||
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
|
||||
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in downloading model: {e}")
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
|
||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||
os.makedirs(full_model_dir, exist_ok=True)
|
||||
file_path = os.path.join(full_model_dir, model_name)
|
||||
|
||||
# Ensure the resulting path is still within the base directory
|
||||
abs_file_path = os.path.abspath(file_path)
|
||||
abs_base_dir = os.path.abspath(str(models_base_dir))
|
||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
||||
|
||||
|
||||
relative_path = '/'.join([model_directory, model_name])
|
||||
return file_path, relative_path
|
||||
|
||||
async def check_file_exists(file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||
if os.path.exists(file_path):
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||
await progress_callback(relative_path, status)
|
||||
return status
|
||||
return None
|
||||
|
||||
|
||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str,
|
||||
interval: float = 1.0) -> DownloadModelStatus:
|
||||
try:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded = 0
|
||||
last_update_time = time.time()
|
||||
|
||||
async def update_progress():
|
||||
nonlocal last_update_time
|
||||
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
|
||||
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
last_update_time = time.time()
|
||||
|
||||
with open(file_path, 'wb') as f:
|
||||
chunk_iterator = response.content.iter_chunked(8192)
|
||||
while True:
|
||||
try:
|
||||
chunk = await chunk_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
await update_progress()
|
||||
|
||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
|
||||
return status
|
||||
except Exception as e:
|
||||
logging.error(f"Error in track_download_progress: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
async def handle_download_error(e: Exception,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||
relative_path: str) -> DownloadModelStatus:
|
||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(relative_path, status)
|
||||
return status
|
||||
|
||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||
"""
|
||||
Validate that the model subdirectory is safe to install into.
|
||||
Must not contain relative paths, nested paths or special characters
|
||||
other than underscores and hyphens.
|
||||
|
||||
Args:
|
||||
model_subdirectory (str): The subdirectory for the specific model type.
|
||||
|
||||
Returns:
|
||||
bool: True if the subdirectory is safe, False otherwise.
|
||||
"""
|
||||
if len(model_subdirectory) > 50:
|
||||
return False
|
||||
|
||||
if '..' in model_subdirectory or '/' in model_subdirectory:
|
||||
return False
|
||||
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def validate_filename(filename: str)-> bool:
|
||||
"""
|
||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||
|
||||
Args:
|
||||
filename (str): The filename to validate
|
||||
|
||||
Returns:
|
||||
bool: True if the filename is valid, False otherwise
|
||||
"""
|
||||
if not filename.lower().endswith(('.sft', '.safetensors')):
|
||||
return False
|
||||
|
||||
# Check if the filename is empty, None, or just whitespace
|
||||
if not filename or not filename.strip():
|
||||
return False
|
||||
|
||||
# Check for any directory traversal attempts or invalid characters
|
||||
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
|
||||
return False
|
||||
|
||||
# Check if the filename starts with a dot (hidden file)
|
||||
if filename.startswith('.'):
|
||||
return False
|
||||
|
||||
# Use a whitelist of allowed characters
|
||||
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
|
||||
return False
|
||||
|
||||
# Ensure the filename isn't too long
|
||||
if len(filename) > 255:
|
||||
return False
|
||||
|
||||
return True
|
130
nodes.py
130
nodes.py
@@ -47,11 +47,18 @@ MAX_RESOLUTION=16384
|
||||
class CLIPTextEncode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}
|
||||
return {
|
||||
"required": {
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||
|
||||
def encode(self, clip, text):
|
||||
tokens = clip.tokenize(text)
|
||||
@@ -260,11 +267,18 @@ class ConditioningSetTimestepRange:
|
||||
class VAEDecode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
return {
|
||||
"required": {
|
||||
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
|
||||
"vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
OUTPUT_TOOLTIPS = ("The decoded image.",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "latent"
|
||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||
|
||||
def decode(self, vae, samples):
|
||||
return (vae.decode(samples["samples"]), )
|
||||
@@ -506,12 +520,19 @@ class CheckpointLoader:
|
||||
class CheckpointLoaderSimple:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
}}
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||
|
||||
def load_checkpoint(self, ckpt_name):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
@@ -582,16 +603,22 @@ class LoraLoader:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"clip": ("CLIP", ),
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
||||
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP")
|
||||
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
|
||||
FUNCTION = "load_lora"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
@@ -638,6 +665,8 @@ class VAELoader:
|
||||
sd1_taesd_dec = False
|
||||
sd3_taesd_enc = False
|
||||
sd3_taesd_dec = False
|
||||
f1_taesd_enc = False
|
||||
f1_taesd_dec = False
|
||||
|
||||
for v in approx_vaes:
|
||||
if v.startswith("taesd_decoder."):
|
||||
@@ -652,12 +681,18 @@ class VAELoader:
|
||||
sd3_taesd_dec = True
|
||||
elif v.startswith("taesd3_encoder."):
|
||||
sd3_taesd_enc = True
|
||||
elif v.startswith("taef1_encoder."):
|
||||
f1_taesd_dec = True
|
||||
elif v.startswith("taef1_decoder."):
|
||||
f1_taesd_enc = True
|
||||
if sd1_taesd_dec and sd1_taesd_enc:
|
||||
vaes.append("taesd")
|
||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||
vaes.append("taesdxl")
|
||||
if sd3_taesd_dec and sd3_taesd_enc:
|
||||
vaes.append("taesd3")
|
||||
if f1_taesd_dec and f1_taesd_enc:
|
||||
vaes.append("taef1")
|
||||
return vaes
|
||||
|
||||
@staticmethod
|
||||
@@ -685,6 +720,9 @@ class VAELoader:
|
||||
elif name == "taesd3":
|
||||
sd["vae_scale"] = torch.tensor(1.5305)
|
||||
sd["vae_shift"] = torch.tensor(0.0609)
|
||||
elif name == "taef1":
|
||||
sd["vae_scale"] = torch.tensor(0.3611)
|
||||
sd["vae_shift"] = torch.tensor(0.1159)
|
||||
return sd
|
||||
|
||||
@classmethod
|
||||
@@ -697,7 +735,7 @@ class VAELoader:
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
if vae_name in ["taesd", "taesdxl", "taesd3"]:
|
||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||
sd = self.load_taesd(vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
@@ -817,7 +855,7 @@ class ControlNetApplyAdvanced:
|
||||
class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
@@ -826,14 +864,14 @@ class UNETLoader:
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name, weight_dtype):
|
||||
dtype = None
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
dtype = torch.float8_e4m3fn
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
dtype = torch.float8_e5m2
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path, dtype=dtype)
|
||||
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
class CLIPLoader:
|
||||
@@ -1033,13 +1071,19 @@ class EmptyLatentImage:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
||||
return {
|
||||
"required": {
|
||||
"width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
|
||||
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
OUTPUT_TOOLTIPS = ("The empty latent image batch.",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent"
|
||||
DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
@@ -1359,24 +1403,27 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
class KSampler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"latent_image": ("LATENT", ),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
|
||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
|
||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}),
|
||||
"positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}),
|
||||
"negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}),
|
||||
"latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
OUTPUT_TOOLTIPS = ("The denoised latent.",)
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "sampling"
|
||||
DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
|
||||
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
|
||||
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
|
||||
@@ -1424,11 +1471,15 @@ class SaveImage:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE", {"tooltip": "The images to save."}),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
@@ -1436,6 +1487,7 @@ class SaveImage:
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
|
||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
@@ -2077,3 +2129,5 @@ def init_extra_nodes(init_custom_nodes=True):
|
||||
else:
|
||||
logging.warning("Please do a: pip install -r requirements.txt")
|
||||
logging.warning("")
|
||||
|
||||
return import_failed
|
||||
|
@@ -79,7 +79,7 @@
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
||||
"\n",
|
||||
"# SD1.5\n",
|
||||
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
||||
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"# SD2\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
||||
|
@@ -1,6 +1,7 @@
|
||||
[pytest]
|
||||
markers =
|
||||
inference: mark as inference test (deselect with '-m "not inference"')
|
||||
execution: mark as execution test (deselect with '-m "not execution"')
|
||||
testpaths =
|
||||
tests
|
||||
tests-unit
|
||||
|
@@ -43,7 +43,7 @@ prompt_text = """
|
||||
"4": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
||||
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
|
@@ -41,15 +41,14 @@ def get_images(ws, prompt):
|
||||
continue #previews are binary data
|
||||
|
||||
history = get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
if 'images' in node_output:
|
||||
images_output = []
|
||||
for image in node_output['images']:
|
||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
images_output = []
|
||||
if 'images' in node_output:
|
||||
for image in node_output['images']:
|
||||
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
return output_images
|
||||
|
||||
@@ -85,7 +84,7 @@ prompt_text = """
|
||||
"4": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
||||
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
|
@@ -81,7 +81,7 @@ prompt_text = """
|
||||
"4": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {
|
||||
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
||||
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
|
83
server.py
83
server.py
@@ -12,7 +12,6 @@ import json
|
||||
import glob
|
||||
import struct
|
||||
import ssl
|
||||
import hashlib
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from io import BytesIO
|
||||
@@ -28,7 +27,9 @@ import comfy.model_management
|
||||
import node_helpers
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.user_manager import UserManager
|
||||
|
||||
from model_filemanager import download_model, DownloadModelStatus
|
||||
from typing import Optional
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
@@ -40,6 +41,21 @@ async def send_socket_catch_exception(function, message):
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
|
||||
def get_comfyui_version():
|
||||
comfyui_version = "unknown"
|
||||
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||
try:
|
||||
import pygit2
|
||||
repo = pygit2.Repository(repo_path)
|
||||
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||
except Exception:
|
||||
try:
|
||||
import subprocess
|
||||
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to get ComfyUI version: {e}")
|
||||
return comfyui_version.strip()
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
@@ -72,10 +88,12 @@ class PromptServer():
|
||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.internal_routes = InternalRoutes()
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = None
|
||||
self.loop = loop
|
||||
self.messages = asyncio.Queue()
|
||||
self.client_session:Optional[aiohttp.ClientSession] = None
|
||||
self.number = 0
|
||||
|
||||
middlewares = [cache_control]
|
||||
@@ -138,6 +156,14 @@ class PromptServer():
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||
|
||||
@routes.get("/models/{folder}")
|
||||
async def get_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = folder_paths.get_filename_list(folder)
|
||||
return web.json_response(files)
|
||||
|
||||
@routes.get("/extensions")
|
||||
async def get_extensions(request):
|
||||
files = glob.glob(os.path.join(
|
||||
@@ -389,16 +415,20 @@ class PromptServer():
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
@routes.get("/system_stats")
|
||||
async def get_queue(request):
|
||||
async def system_stats(request):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
"os": os.name,
|
||||
"comfyui_version": get_comfyui_version(),
|
||||
"python_version": sys.version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
"argv": sys.argv
|
||||
},
|
||||
"devices": [
|
||||
{
|
||||
@@ -422,6 +452,7 @@ class PromptServer():
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
@@ -437,6 +468,14 @@ class PromptServer():
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
|
||||
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
|
||||
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
|
||||
|
||||
if getattr(obj_class, "DEPRECATED", False):
|
||||
info['deprecated'] = True
|
||||
if getattr(obj_class, "EXPERIMENTAL", False):
|
||||
info['experimental'] = True
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
@@ -559,9 +598,42 @@ class PromptServer():
|
||||
self.prompt_queue.delete_history_item(id_to_delete)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
# Internal route. Should not be depended upon and is subject to change at any time.
|
||||
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
|
||||
@routes.post("/internal/models/download")
|
||||
async def download_handler(request):
|
||||
async def report_progress(filename: str, status: DownloadModelStatus):
|
||||
payload = status.to_dict()
|
||||
payload['download_path'] = filename
|
||||
await self.send_json("download_progress", payload)
|
||||
|
||||
data = await request.json()
|
||||
url = data.get('url')
|
||||
model_directory = data.get('model_directory')
|
||||
model_filename = data.get('model_filename')
|
||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||
|
||||
if not url or not model_directory or not model_filename:
|
||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||
|
||||
session = self.client_session
|
||||
if session is None:
|
||||
logging.error("Client session is not initialized")
|
||||
return web.Response(status=500)
|
||||
|
||||
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
|
||||
await task
|
||||
|
||||
return web.json_response(task.result().to_dict())
|
||||
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
# This is very useful for frontend dev server, which need to forward
|
||||
@@ -680,6 +752,9 @@ class PromptServer():
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
|
||||
self.address = address
|
||||
self.port = port
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
||||
|
1
tests-ui/.gitignore
vendored
1
tests-ui/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
node_modules
|
@@ -1,9 +0,0 @@
|
||||
const { start } = require("./utils");
|
||||
const lg = require("./utils/litegraph");
|
||||
|
||||
// Load things once per test file before to ensure its all warmed up for the tests
|
||||
beforeAll(async () => {
|
||||
lg.setup(global);
|
||||
await start({ resetEnv: true });
|
||||
lg.teardown(global);
|
||||
});
|
@@ -1,4 +0,0 @@
|
||||
{
|
||||
"presets": ["@babel/preset-env"],
|
||||
"plugins": ["babel-plugin-transform-import-meta"]
|
||||
}
|
@@ -1,14 +0,0 @@
|
||||
module.exports = async function () {
|
||||
global.ResizeObserver = class ResizeObserver {
|
||||
observe() {}
|
||||
unobserve() {}
|
||||
disconnect() {}
|
||||
};
|
||||
|
||||
const { nop } = require("./utils/nopProxy");
|
||||
global.enableWebGLCanvas = nop;
|
||||
|
||||
HTMLCanvasElement.prototype.getContext = nop;
|
||||
|
||||
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
|
||||
};
|
@@ -1,11 +0,0 @@
|
||||
/** @type {import('jest').Config} */
|
||||
const config = {
|
||||
testEnvironment: "jsdom",
|
||||
setupFiles: ["./globalSetup.js"],
|
||||
setupFilesAfterEnv: ["./afterSetup.js"],
|
||||
clearMocks: true,
|
||||
resetModules: true,
|
||||
testTimeout: 10000
|
||||
};
|
||||
|
||||
module.exports = config;
|
5586
tests-ui/package-lock.json
generated
5586
tests-ui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"name": "comfui-tests",
|
||||
"version": "1.0.0",
|
||||
"description": "UI tests",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"test": "jest",
|
||||
"test:generate": "node setup.js"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
|
||||
},
|
||||
"keywords": [
|
||||
"comfyui",
|
||||
"test"
|
||||
],
|
||||
"author": "comfyanonymous",
|
||||
"license": "GPL-3.0",
|
||||
"bugs": {
|
||||
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
|
||||
},
|
||||
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
|
||||
"devDependencies": {
|
||||
"@babel/preset-env": "^7.22.20",
|
||||
"@types/jest": "^29.5.5",
|
||||
"babel-plugin-transform-import-meta": "^2.2.1",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0"
|
||||
}
|
||||
}
|
@@ -1,88 +0,0 @@
|
||||
const { spawn } = require("child_process");
|
||||
const { resolve } = require("path");
|
||||
const { existsSync, mkdirSync, writeFileSync } = require("fs");
|
||||
const http = require("http");
|
||||
|
||||
async function setup() {
|
||||
// Wait up to 30s for it to start
|
||||
let success = false;
|
||||
let child;
|
||||
for (let i = 0; i < 30; i++) {
|
||||
try {
|
||||
await new Promise((res, rej) => {
|
||||
http
|
||||
.get("http://127.0.0.1:8188/object_info", (resp) => {
|
||||
let data = "";
|
||||
resp.on("data", (chunk) => {
|
||||
data += chunk;
|
||||
});
|
||||
resp.on("end", () => {
|
||||
// Modify the response data to add some checkpoints
|
||||
const objectInfo = JSON.parse(data);
|
||||
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
|
||||
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
|
||||
|
||||
data = JSON.stringify(objectInfo, undefined, "\t");
|
||||
|
||||
const outDir = resolve("./data");
|
||||
if (!existsSync(outDir)) {
|
||||
mkdirSync(outDir);
|
||||
}
|
||||
|
||||
const outPath = resolve(outDir, "object_info.json");
|
||||
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
|
||||
writeFileSync(outPath, data, {
|
||||
encoding: "utf8",
|
||||
});
|
||||
res();
|
||||
});
|
||||
})
|
||||
.on("error", rej);
|
||||
});
|
||||
success = true;
|
||||
break;
|
||||
} catch (error) {
|
||||
console.log(i + "/30", error);
|
||||
if (i === 0) {
|
||||
// Start the server on first iteration if it fails to connect
|
||||
console.log("Starting ComfyUI server...");
|
||||
|
||||
let python = resolve("../../python_embeded/python.exe");
|
||||
let args;
|
||||
let cwd;
|
||||
if (existsSync(python)) {
|
||||
args = ["-s", "ComfyUI/main.py"];
|
||||
cwd = "../..";
|
||||
} else {
|
||||
python = "python";
|
||||
args = ["main.py"];
|
||||
cwd = "..";
|
||||
}
|
||||
args.push("--cpu");
|
||||
console.log(python, ...args);
|
||||
child = spawn(python, args, { cwd });
|
||||
child.on("error", (err) => {
|
||||
console.log(`Server error (${err})`);
|
||||
i = 30;
|
||||
});
|
||||
child.on("exit", (code) => {
|
||||
if (!success) {
|
||||
console.log(`Server exited (${code})`);
|
||||
i = 30;
|
||||
}
|
||||
});
|
||||
}
|
||||
await new Promise((r) => {
|
||||
setTimeout(r, 1000);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
child?.kill();
|
||||
|
||||
if (!success) {
|
||||
throw new Error("Waiting for server failed...");
|
||||
}
|
||||
}
|
||||
|
||||
setup();
|
@@ -1,196 +0,0 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
const { start } = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
describe("extensions", () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global);
|
||||
});
|
||||
|
||||
it("calls each extension hook", async () => {
|
||||
const mockExtension = {
|
||||
name: "TestExtension",
|
||||
init: jest.fn(),
|
||||
setup: jest.fn(),
|
||||
addCustomNodeDefs: jest.fn(),
|
||||
getCustomWidgets: jest.fn(),
|
||||
beforeRegisterNodeDef: jest.fn(),
|
||||
registerCustomNodes: jest.fn(),
|
||||
loadedGraphNode: jest.fn(),
|
||||
nodeCreated: jest.fn(),
|
||||
beforeConfigureGraph: jest.fn(),
|
||||
afterConfigureGraph: jest.fn(),
|
||||
};
|
||||
|
||||
const { app, ez, graph } = await start({
|
||||
async preSetup(app) {
|
||||
app.registerExtension(mockExtension);
|
||||
},
|
||||
});
|
||||
|
||||
// Basic initialisation hooks should be called once, with app
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.init).toHaveBeenCalledWith(app);
|
||||
|
||||
// Adding custom node defs should be passed the full list of nodes
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
|
||||
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
|
||||
expect(defs).toHaveProperty("KSampler");
|
||||
expect(defs).toHaveProperty("LoadImage");
|
||||
|
||||
// Get custom widgets is called once and should return new widget types
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
|
||||
|
||||
// Before register node def will be called once per node type
|
||||
const nodeNames = Object.keys(defs);
|
||||
const nodeCount = nodeNames.length;
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
||||
for (let i = 0; i < 10; i++) {
|
||||
// It should be send the JS class and the original JSON definition
|
||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
|
||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
|
||||
|
||||
expect(nodeClass.name).toBe("ComfyNode");
|
||||
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
|
||||
expect(nodeDef.name).toBe(nodeNames[i]);
|
||||
expect(nodeDef).toHaveProperty("input");
|
||||
expect(nodeDef).toHaveProperty("output");
|
||||
}
|
||||
|
||||
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Before configure graph will be called here as the default graph is being loaded
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
|
||||
// it gets sent the graph data that is going to be loaded
|
||||
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
|
||||
|
||||
// A node created is fired for each node constructor that is called
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
|
||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
||||
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
||||
}
|
||||
|
||||
// Each node then calls loadedGraphNode to allow them to be updated
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
||||
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
||||
}
|
||||
|
||||
// After configure is then called once all the setup is done
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
|
||||
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.setup).toHaveBeenCalledWith(app);
|
||||
|
||||
// Ensure hooks are called in the correct order
|
||||
const callOrder = [
|
||||
"init",
|
||||
"addCustomNodeDefs",
|
||||
"getCustomWidgets",
|
||||
"beforeRegisterNodeDef",
|
||||
"registerCustomNodes",
|
||||
"beforeConfigureGraph",
|
||||
"nodeCreated",
|
||||
"loadedGraphNode",
|
||||
"afterConfigureGraph",
|
||||
"setup",
|
||||
];
|
||||
for (let i = 1; i < callOrder.length; i++) {
|
||||
const fn1 = mockExtension[callOrder[i - 1]];
|
||||
const fn2 = mockExtension[callOrder[i]];
|
||||
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
|
||||
}
|
||||
|
||||
graph.clear();
|
||||
|
||||
// Ensure adding a new node calls the correct callback
|
||||
ez.LoadImage();
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
||||
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
|
||||
|
||||
// Reload the graph to ensure correct hooks are fired
|
||||
await graph.reload();
|
||||
|
||||
// These hooks should not be fired again
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
||||
|
||||
// These should be called again
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
|
||||
}, 15000);
|
||||
|
||||
it("allows custom nodeDefs and widgets to be registered", async () => {
|
||||
const widgetMock = jest.fn((node, inputName, inputData, app) => {
|
||||
expect(node.constructor.comfyClass).toBe("TestNode");
|
||||
expect(inputName).toBe("test_input");
|
||||
expect(inputData[0]).toBe("CUSTOMWIDGET");
|
||||
expect(inputData[1]?.hello).toBe("world");
|
||||
expect(app).toStrictEqual(app);
|
||||
|
||||
return {
|
||||
widget: node.addWidget("button", inputName, "hello", () => {}),
|
||||
};
|
||||
});
|
||||
|
||||
// Register our extension that adds a custom node + widget type
|
||||
const mockExtension = {
|
||||
name: "TestExtension",
|
||||
addCustomNodeDefs: (nodeDefs) => {
|
||||
nodeDefs["TestNode"] = {
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
name: "TestNode",
|
||||
display_name: "TestNode",
|
||||
category: "Test",
|
||||
input: {
|
||||
required: {
|
||||
test_input: ["CUSTOMWIDGET", { hello: "world" }],
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
getCustomWidgets: jest.fn(() => {
|
||||
return {
|
||||
CUSTOMWIDGET: widgetMock,
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const { graph, ez } = await start({
|
||||
async preSetup(app) {
|
||||
app.registerExtension(mockExtension);
|
||||
},
|
||||
});
|
||||
|
||||
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
|
||||
|
||||
graph.clear();
|
||||
expect(widgetMock).toBeCalledTimes(0);
|
||||
const node = ez.TestNode();
|
||||
expect(widgetMock).toBeCalledTimes(1);
|
||||
|
||||
// Ensure our custom widget is created
|
||||
expect(node.inputs.length).toBe(0);
|
||||
expect(node.widgets.length).toBe(1);
|
||||
const w = node.widgets[0].widget;
|
||||
expect(w.name).toBe("test_input");
|
||||
expect(w.type).toBe("button");
|
||||
});
|
||||
});
|
File diff suppressed because it is too large
Load Diff
@@ -1,295 +0,0 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
const { start } = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
describe("users", () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global);
|
||||
});
|
||||
|
||||
function expectNoUserScreen() {
|
||||
// Ensure login isnt visible
|
||||
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
|
||||
expect(selection["style"].display).toBe("none");
|
||||
const menu = document.querySelectorAll(".comfy-menu")?.[0];
|
||||
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
|
||||
}
|
||||
|
||||
describe("multi-user", () => {
|
||||
function mockAddStylesheet() {
|
||||
const utils = require("../../web/scripts/utils");
|
||||
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
|
||||
}
|
||||
|
||||
async function waitForUserScreenShow() {
|
||||
mockAddStylesheet();
|
||||
|
||||
// Wait for "show" to be called
|
||||
const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection");
|
||||
let resolve, reject;
|
||||
const fn = UserSelectionScreen.prototype.show;
|
||||
const p = new Promise((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
|
||||
const res = fn(...args);
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
resolve();
|
||||
return res;
|
||||
});
|
||||
// @ts-ignore
|
||||
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
|
||||
await p;
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
}
|
||||
|
||||
async function testUserScreen(onShown, users) {
|
||||
if (!users) {
|
||||
users = {};
|
||||
}
|
||||
const starting = start({
|
||||
resetEnv: true,
|
||||
userConfig: { storage: "server", users },
|
||||
});
|
||||
|
||||
// Ensure no current user
|
||||
expect(localStorage["Comfy.userId"]).toBeFalsy();
|
||||
expect(localStorage["Comfy.userName"]).toBeFalsy();
|
||||
|
||||
await waitForUserScreenShow();
|
||||
|
||||
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
|
||||
expect(selection).toBeTruthy();
|
||||
|
||||
// Ensure login is visible
|
||||
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
|
||||
// Ensure menu is hidden
|
||||
const menu = document.querySelectorAll(".comfy-menu")?.[0];
|
||||
expect(window.getComputedStyle(menu)?.display).toBe("none");
|
||||
|
||||
const isCreate = await onShown(selection);
|
||||
|
||||
// Submit form
|
||||
selection.querySelectorAll("form")[0].submit();
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
|
||||
// Wait for start
|
||||
const s = await starting;
|
||||
|
||||
// Ensure login is removed
|
||||
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
|
||||
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
|
||||
|
||||
// Ensure settings + templates are saved
|
||||
const { api } = require("../../web/scripts/api");
|
||||
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
|
||||
if (isCreate) {
|
||||
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
|
||||
expect(s.app.isNewUserSession).toBeTruthy();
|
||||
} else {
|
||||
expect(s.app.isNewUserSession).toBeFalsy();
|
||||
}
|
||||
|
||||
return { users, selection, ...s };
|
||||
}
|
||||
|
||||
it("allows user creation if no users", async () => {
|
||||
const { users } = await testUserScreen((selection) => {
|
||||
// Ensure we have no users flag added
|
||||
expect(selection.classList.contains("no-users")).toBeTruthy();
|
||||
|
||||
// Enter a username
|
||||
const input = selection.getElementsByTagName("input")[0];
|
||||
input.focus();
|
||||
input.value = "Test User";
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
expect(users).toStrictEqual({
|
||||
"Test User!": "Test User",
|
||||
});
|
||||
|
||||
expect(localStorage["Comfy.userId"]).toBe("Test User!");
|
||||
expect(localStorage["Comfy.userName"]).toBe("Test User");
|
||||
});
|
||||
it("allows user creation if no current user but other users", async () => {
|
||||
const users = {
|
||||
"Test User 2!": "Test User 2",
|
||||
};
|
||||
|
||||
await testUserScreen((selection) => {
|
||||
expect(selection.classList.contains("no-users")).toBeFalsy();
|
||||
|
||||
// Enter a username
|
||||
const input = selection.getElementsByTagName("input")[0];
|
||||
input.focus();
|
||||
input.value = "Test User 3";
|
||||
return true;
|
||||
}, users);
|
||||
|
||||
expect(users).toStrictEqual({
|
||||
"Test User 2!": "Test User 2",
|
||||
"Test User 3!": "Test User 3",
|
||||
});
|
||||
|
||||
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
|
||||
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
|
||||
});
|
||||
it("allows user selection if no current user but other users", async () => {
|
||||
const users = {
|
||||
"A!": "A",
|
||||
"B!": "B",
|
||||
"C!": "C",
|
||||
};
|
||||
|
||||
await testUserScreen((selection) => {
|
||||
expect(selection.classList.contains("no-users")).toBeFalsy();
|
||||
|
||||
// Check user list
|
||||
const select = selection.getElementsByTagName("select")[0];
|
||||
const options = select.getElementsByTagName("option");
|
||||
expect(
|
||||
[...options]
|
||||
.filter((o) => !o.disabled)
|
||||
.reduce((p, n) => {
|
||||
p[n.getAttribute("value")] = n.textContent;
|
||||
return p;
|
||||
}, {})
|
||||
).toStrictEqual(users);
|
||||
|
||||
// Select an option
|
||||
select.focus();
|
||||
select.value = options[2].value;
|
||||
|
||||
return false;
|
||||
}, users);
|
||||
|
||||
expect(users).toStrictEqual(users);
|
||||
|
||||
expect(localStorage["Comfy.userId"]).toBe("B!");
|
||||
expect(localStorage["Comfy.userName"]).toBe("B");
|
||||
});
|
||||
it("doesnt show user screen if current user", async () => {
|
||||
const starting = start({
|
||||
resetEnv: true,
|
||||
userConfig: {
|
||||
storage: "server",
|
||||
users: {
|
||||
"User!": "User",
|
||||
},
|
||||
},
|
||||
localStorage: {
|
||||
"Comfy.userId": "User!",
|
||||
"Comfy.userName": "User",
|
||||
},
|
||||
});
|
||||
await new Promise(process.nextTick); // wait for promises to resolve
|
||||
|
||||
expectNoUserScreen();
|
||||
|
||||
await starting;
|
||||
});
|
||||
it("allows user switching", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: {
|
||||
storage: "server",
|
||||
users: {
|
||||
"User!": "User",
|
||||
},
|
||||
},
|
||||
localStorage: {
|
||||
"Comfy.userId": "User!",
|
||||
"Comfy.userName": "User",
|
||||
},
|
||||
});
|
||||
|
||||
// cant actually test switching user easily but can check the setting is present
|
||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
|
||||
});
|
||||
});
|
||||
describe("single-user", () => {
|
||||
it("doesnt show user creation if no default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: false, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(1);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(1);
|
||||
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
|
||||
expect(app.isNewUserSession).toBeTruthy();
|
||||
});
|
||||
it("doesnt show user creation if default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
||||
expect(app.isNewUserSession).toBeFalsy();
|
||||
});
|
||||
it("doesnt allow user switching", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
|
||||
});
|
||||
});
|
||||
describe("browser-user", () => {
|
||||
it("doesnt show user creation if no default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: false, storage: "browser" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
||||
expect(app.isNewUserSession).toBeFalsy();
|
||||
});
|
||||
it("doesnt show user creation if default user", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "server" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
// It should store the settings
|
||||
const { api } = require("../../web/scripts/api");
|
||||
expect(api.storeSettings).toHaveBeenCalledTimes(0);
|
||||
expect(api.storeUserData).toHaveBeenCalledTimes(0);
|
||||
expect(app.isNewUserSession).toBeFalsy();
|
||||
});
|
||||
it("doesnt allow user switching", async () => {
|
||||
const { app } = await start({
|
||||
resetEnv: true,
|
||||
userConfig: { migrated: true, storage: "browser" },
|
||||
});
|
||||
expectNoUserScreen();
|
||||
|
||||
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
|
||||
});
|
||||
});
|
||||
});
|
@@ -1,557 +0,0 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
|
||||
const {
|
||||
start,
|
||||
makeNodeDef,
|
||||
checkBeforeAndAfterReload,
|
||||
assertNotNullOrUndefined,
|
||||
createDefaultWorkflow,
|
||||
} = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
/**
|
||||
* @typedef { import("../utils/ezgraph") } Ez
|
||||
* @typedef { ReturnType<Ez["Ez"]["graph"]>["ez"] } EzNodeFactory
|
||||
*/
|
||||
|
||||
/**
|
||||
* @param { EzNodeFactory } ez
|
||||
* @param { InstanceType<Ez["EzGraph"]> } graph
|
||||
* @param { InstanceType<Ez["EzInput"]> } input
|
||||
* @param { string } widgetType
|
||||
* @param { number } controlWidgetCount
|
||||
* @returns
|
||||
*/
|
||||
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
|
||||
// Connect to primitive and ensure its still connected after
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(input);
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async () => {
|
||||
primitive = graph.find(primitive);
|
||||
let { connections } = primitive.outputs[0];
|
||||
expect(connections).toHaveLength(1);
|
||||
expect(connections[0].targetNode.id).toBe(input.node.node.id);
|
||||
|
||||
// Ensure widget is correct type
|
||||
const valueWidget = primitive.widgets.value;
|
||||
expect(valueWidget.widget.type).toBe(widgetType);
|
||||
|
||||
// Check if control_after_generate should be added
|
||||
if (controlWidgetCount) {
|
||||
const controlWidget = primitive.widgets.control_after_generate;
|
||||
expect(controlWidget.widget.type).toBe("combo");
|
||||
if (widgetType === "combo") {
|
||||
const filterWidget = primitive.widgets.control_filter_list;
|
||||
expect(filterWidget.widget.type).toBe("string");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we dont have other widgets
|
||||
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
|
||||
});
|
||||
|
||||
return primitive;
|
||||
}
|
||||
|
||||
describe("widget inputs", () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global);
|
||||
});
|
||||
|
||||
[
|
||||
{ name: "int", type: "INT", widget: "number", control: 1 },
|
||||
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
|
||||
{ name: "text", type: "STRING" },
|
||||
{
|
||||
name: "customtext",
|
||||
type: "STRING",
|
||||
opt: { multiline: true },
|
||||
},
|
||||
{ name: "toggle", type: "BOOLEAN" },
|
||||
{ name: "combo", type: ["a", "b", "c"], control: 2 },
|
||||
].forEach((c) => {
|
||||
test(`widget conversion + primitive works on ${c.name}`, async () => {
|
||||
const { ez, graph } = await start({
|
||||
mockNodeDefs: makeNodeDef("TestNode", { [c.name]: [c.type, c.opt ?? {}] }),
|
||||
});
|
||||
|
||||
// Create test node and convert to input
|
||||
const n = ez.TestNode();
|
||||
const w = n.widgets[c.name];
|
||||
w.convertToInput();
|
||||
expect(w.isConvertedToInput).toBeTruthy();
|
||||
const input = w.getConvertedInput();
|
||||
expect(input).toBeTruthy();
|
||||
|
||||
// @ts-ignore : input is valid here
|
||||
await connectPrimitiveAndReload(ez, graph, input, c.widget ?? c.name, c.control);
|
||||
});
|
||||
});
|
||||
|
||||
test("converted widget works after reload", async () => {
|
||||
const { ez, graph } = await start();
|
||||
let n = ez.CheckpointLoaderSimple();
|
||||
|
||||
const inputCount = n.inputs.length;
|
||||
|
||||
// Convert ckpt name to an input
|
||||
n.widgets.ckpt_name.convertToInput();
|
||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
|
||||
expect(n.inputs.ckpt_name).toBeTruthy();
|
||||
expect(n.inputs.length).toEqual(inputCount + 1);
|
||||
|
||||
// Convert back to widget and ensure input is removed
|
||||
n.widgets.ckpt_name.convertToWidget();
|
||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
|
||||
expect(n.inputs.ckpt_name).toBeFalsy();
|
||||
expect(n.inputs.length).toEqual(inputCount);
|
||||
|
||||
// Convert again and reload the graph to ensure it maintains state
|
||||
n.widgets.ckpt_name.convertToInput();
|
||||
expect(n.inputs.length).toEqual(inputCount + 1);
|
||||
|
||||
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
|
||||
|
||||
// Disconnect & reconnect
|
||||
primitive.outputs[0].connections[0].disconnect();
|
||||
let { connections } = primitive.outputs[0];
|
||||
expect(connections).toHaveLength(0);
|
||||
|
||||
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
|
||||
({ connections } = primitive.outputs[0]);
|
||||
expect(connections).toHaveLength(1);
|
||||
expect(connections[0].targetNode.id).toBe(n.node.id);
|
||||
|
||||
// Convert back to widget and ensure input is removed
|
||||
n.widgets.ckpt_name.convertToWidget();
|
||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
|
||||
expect(n.inputs.ckpt_name).toBeFalsy();
|
||||
expect(n.inputs.length).toEqual(inputCount);
|
||||
});
|
||||
|
||||
test("converted widget works on clone", async () => {
|
||||
const { graph, ez } = await start();
|
||||
let n = ez.CheckpointLoaderSimple();
|
||||
|
||||
// Convert the widget to an input
|
||||
n.widgets.ckpt_name.convertToInput();
|
||||
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
|
||||
|
||||
// Clone the node
|
||||
n.menu["Clone"].call();
|
||||
expect(graph.nodes).toHaveLength(2);
|
||||
const clone = graph.nodes[1];
|
||||
expect(clone.id).not.toEqual(n.id);
|
||||
|
||||
// Ensure the clone has an input
|
||||
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
|
||||
expect(clone.inputs.ckpt_name).toBeTruthy();
|
||||
|
||||
// Ensure primitive connects to both nodes
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
|
||||
primitive.outputs[0].connectTo(clone.inputs.ckpt_name);
|
||||
expect(primitive.outputs[0].connections).toHaveLength(2);
|
||||
|
||||
// Convert back to widget and ensure input is removed
|
||||
clone.widgets.ckpt_name.convertToWidget();
|
||||
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
|
||||
expect(clone.inputs.ckpt_name).toBeFalsy();
|
||||
});
|
||||
|
||||
test("shows missing node error on custom node with converted input", async () => {
|
||||
const { graph } = await start();
|
||||
|
||||
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
|
||||
|
||||
await graph.app.loadGraphData({
|
||||
last_node_id: 3,
|
||||
last_link_id: 4,
|
||||
nodes: [
|
||||
{
|
||||
id: 1,
|
||||
type: "TestNode",
|
||||
pos: [41.87329101561909, 389.7381480823742],
|
||||
size: { 0: 220, 1: 374 },
|
||||
flags: {},
|
||||
order: 1,
|
||||
mode: 0,
|
||||
inputs: [{ name: "test", type: "FLOAT", link: 4, widget: { name: "test" }, slot_index: 0 }],
|
||||
outputs: [],
|
||||
properties: { "Node name for S&R": "TestNode" },
|
||||
widgets_values: [1],
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
type: "PrimitiveNode",
|
||||
pos: [-312, 433],
|
||||
size: { 0: 210, 1: 82 },
|
||||
flags: {},
|
||||
order: 0,
|
||||
mode: 0,
|
||||
outputs: [{ links: [4], widget: { name: "test" } }],
|
||||
title: "test",
|
||||
properties: {},
|
||||
},
|
||||
],
|
||||
links: [[4, 3, 0, 1, 6, "FLOAT"]],
|
||||
groups: [],
|
||||
config: {},
|
||||
extra: {},
|
||||
version: 0.4,
|
||||
});
|
||||
|
||||
expect(dialogShow).toBeCalledTimes(1);
|
||||
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
|
||||
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
|
||||
});
|
||||
|
||||
test("defaultInput widgets can be converted back to inputs", async () => {
|
||||
const { graph, ez } = await start({
|
||||
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { defaultInput: true }] }),
|
||||
});
|
||||
|
||||
// Create test node and ensure it starts as an input
|
||||
let n = ez.TestNode();
|
||||
let w = n.widgets.example;
|
||||
expect(w.isConvertedToInput).toBeTruthy();
|
||||
let input = w.getConvertedInput();
|
||||
expect(input).toBeTruthy();
|
||||
|
||||
// Ensure it can be converted to
|
||||
w.convertToWidget();
|
||||
expect(w.isConvertedToInput).toBeFalsy();
|
||||
expect(n.inputs.length).toEqual(0);
|
||||
// and from
|
||||
w.convertToInput();
|
||||
expect(w.isConvertedToInput).toBeTruthy();
|
||||
input = w.getConvertedInput();
|
||||
|
||||
// Reload and ensure it still only has 1 converted widget
|
||||
if (!assertNotNullOrUndefined(input)) return;
|
||||
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
w = n.widgets.example;
|
||||
expect(w.isConvertedToInput).toBeTruthy();
|
||||
|
||||
// Convert back to widget and ensure it is still a widget after reload
|
||||
w.convertToWidget();
|
||||
await graph.reload();
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
expect(n.widgets[0].isConvertedToInput).toBeFalsy();
|
||||
expect(n.inputs.length).toEqual(0);
|
||||
});
|
||||
|
||||
test("forceInput widgets can not be converted back to inputs", async () => {
|
||||
const { graph, ez } = await start({
|
||||
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { forceInput: true }] }),
|
||||
});
|
||||
|
||||
// Create test node and ensure it starts as an input
|
||||
let n = ez.TestNode();
|
||||
let w = n.widgets.example;
|
||||
expect(w.isConvertedToInput).toBeTruthy();
|
||||
const input = w.getConvertedInput();
|
||||
expect(input).toBeTruthy();
|
||||
|
||||
// Convert to widget should error
|
||||
expect(() => w.convertToWidget()).toThrow();
|
||||
|
||||
// Reload and ensure it still only has 1 converted widget
|
||||
if (assertNotNullOrUndefined(input)) {
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
|
||||
}
|
||||
});
|
||||
|
||||
test("primitive can connect to matching combos on converted widgets", async () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
|
||||
...makeNodeDef("TestNode2", { example: [["A", "B", "C"], { forceInput: true }] }),
|
||||
},
|
||||
});
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
const n2 = ez.TestNode2();
|
||||
const p = ez.PrimitiveNode();
|
||||
p.outputs[0].connectTo(n1.inputs[0]);
|
||||
p.outputs[0].connectTo(n2.inputs[0]);
|
||||
expect(p.outputs[0].connections).toHaveLength(2);
|
||||
const valueWidget = p.widgets.value;
|
||||
expect(valueWidget.widget.type).toBe("combo");
|
||||
expect(valueWidget.widget.options.values).toEqual(["A", "B", "C"]);
|
||||
});
|
||||
|
||||
test("primitive can not connect to non matching combos on converted widgets", async () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
|
||||
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
|
||||
},
|
||||
});
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
const n2 = ez.TestNode2();
|
||||
const p = ez.PrimitiveNode();
|
||||
p.outputs[0].connectTo(n1.inputs[0]);
|
||||
expect(() => p.outputs[0].connectTo(n2.inputs[0])).toThrow();
|
||||
expect(p.outputs[0].connections).toHaveLength(1);
|
||||
});
|
||||
|
||||
test("combo output can not connect to non matching combos list input", async () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
|
||||
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
|
||||
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
|
||||
},
|
||||
});
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
const n2 = ez.TestNode2();
|
||||
const n3 = ez.TestNode3();
|
||||
|
||||
n1.outputs[0].connectTo(n2.inputs[0]);
|
||||
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
|
||||
});
|
||||
|
||||
test("combo primitive can filter list when control_after_generate called", async () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
|
||||
},
|
||||
});
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
n1.widgets.example.convertToInput();
|
||||
const p = ez.PrimitiveNode();
|
||||
p.outputs[0].connectTo(n1.inputs[0]);
|
||||
|
||||
const value = p.widgets.value;
|
||||
const control = p.widgets.control_after_generate.widget;
|
||||
const filter = p.widgets.control_filter_list;
|
||||
|
||||
expect(p.widgets.length).toBe(3);
|
||||
control.value = "increment";
|
||||
expect(value.value).toBe("A");
|
||||
|
||||
// Manually trigger after queue when set to increment
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
|
||||
// Filter to items containing D
|
||||
filter.value = "D";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("D");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("DD");
|
||||
|
||||
// Check decrement
|
||||
value.value = "BBB";
|
||||
control.value = "decrement";
|
||||
filter.value = "B";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("BB");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
|
||||
// Check regex works
|
||||
value.value = "BBB";
|
||||
filter.value = "/[AB]|^C$/";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("AAA");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("BB");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("AA");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("C");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("A");
|
||||
|
||||
// Check random
|
||||
control.value = "randomize";
|
||||
filter.value = "/D/";
|
||||
for (let i = 0; i < 100; i++) {
|
||||
control["afterQueued"]();
|
||||
expect(value.value === "D" || value.value === "DD").toBeTruthy();
|
||||
}
|
||||
|
||||
// Ensure it doesnt apply when fixed
|
||||
control.value = "fixed";
|
||||
value.value = "B";
|
||||
filter.value = "C";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
});
|
||||
|
||||
describe("reroutes", () => {
|
||||
async function checkOutput(graph, values) {
|
||||
expect((await graph.toPrompt()).output).toStrictEqual({
|
||||
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
|
||||
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
|
||||
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
|
||||
4: {
|
||||
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
|
||||
class_type: "EmptyLatentImage",
|
||||
},
|
||||
5: {
|
||||
inputs: {
|
||||
seed: 0,
|
||||
steps: 20,
|
||||
cfg: 8,
|
||||
sampler_name: "euler",
|
||||
scheduler: values?.scheduler ?? "normal",
|
||||
denoise: 1,
|
||||
model: ["1", 0],
|
||||
positive: ["2", 0],
|
||||
negative: ["3", 0],
|
||||
latent_image: ["4", 0],
|
||||
},
|
||||
class_type: "KSampler",
|
||||
},
|
||||
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
|
||||
7: {
|
||||
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
|
||||
class_type: "SaveImage",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async function waitForWidget(node) {
|
||||
// widgets are created slightly after the graph is ready
|
||||
// hard to find an exact hook to get these so just wait for them to be ready
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await new Promise((r) => setTimeout(r, 10));
|
||||
if (node.widgets?.value) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it("can connect primitive via a reroute path to a widget input", async () => {
|
||||
const { ez, graph } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
nodes.empty.widgets.width.convertToInput();
|
||||
nodes.sampler.widgets.scheduler.convertToInput();
|
||||
nodes.save.widgets.filename_prefix.convertToInput();
|
||||
|
||||
let widthReroute = ez.Reroute();
|
||||
let schedulerReroute = ez.Reroute();
|
||||
let fileReroute = ez.Reroute();
|
||||
|
||||
let widthNext = widthReroute;
|
||||
let schedulerNext = schedulerReroute;
|
||||
let fileNext = fileReroute;
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
let next = ez.Reroute();
|
||||
widthNext.outputs[0].connectTo(next.inputs[0]);
|
||||
widthNext = next;
|
||||
|
||||
next = ez.Reroute();
|
||||
schedulerNext.outputs[0].connectTo(next.inputs[0]);
|
||||
schedulerNext = next;
|
||||
|
||||
next = ez.Reroute();
|
||||
fileNext.outputs[0].connectTo(next.inputs[0]);
|
||||
fileNext = next;
|
||||
}
|
||||
|
||||
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
|
||||
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
|
||||
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
|
||||
|
||||
let widthPrimitive = ez.PrimitiveNode();
|
||||
let schedulerPrimitive = ez.PrimitiveNode();
|
||||
let filePrimitive = ez.PrimitiveNode();
|
||||
|
||||
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
|
||||
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
|
||||
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
|
||||
expect(widthPrimitive.widgets.value.value).toBe(512);
|
||||
widthPrimitive.widgets.value.value = 1024;
|
||||
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
|
||||
schedulerPrimitive.widgets.value.value = "simple";
|
||||
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
|
||||
filePrimitive.widgets.value.value = "ComfyTest";
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async () => {
|
||||
widthPrimitive = graph.find(widthPrimitive);
|
||||
schedulerPrimitive = graph.find(schedulerPrimitive);
|
||||
filePrimitive = graph.find(filePrimitive);
|
||||
await waitForWidget(filePrimitive);
|
||||
expect(widthPrimitive.widgets.length).toBe(2);
|
||||
expect(schedulerPrimitive.widgets.length).toBe(3);
|
||||
expect(filePrimitive.widgets.length).toBe(1);
|
||||
|
||||
await checkOutput(graph, {
|
||||
width: 1024,
|
||||
scheduler: "simple",
|
||||
filename_prefix: "ComfyTest",
|
||||
});
|
||||
});
|
||||
});
|
||||
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
|
||||
const { ez, graph } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
nodes.empty.widgets.width.convertToInput();
|
||||
nodes.empty.widgets.height.convertToInput();
|
||||
nodes.empty.widgets.batch_size.convertToInput();
|
||||
|
||||
let reroute = ez.Reroute();
|
||||
let prevReroute = reroute;
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const next = ez.Reroute();
|
||||
prevReroute.outputs[0].connectTo(next.inputs[0]);
|
||||
prevReroute = next;
|
||||
}
|
||||
|
||||
const r1 = ez.Reroute(prevReroute.outputs[0]);
|
||||
const r2 = ez.Reroute(prevReroute.outputs[0]);
|
||||
const r3 = ez.Reroute(r2.outputs[0]);
|
||||
const r4 = ez.Reroute(r2.outputs[0]);
|
||||
|
||||
r1.outputs[0].connectTo(nodes.empty.inputs.width);
|
||||
r3.outputs[0].connectTo(nodes.empty.inputs.height);
|
||||
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
|
||||
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(reroute.inputs[0]);
|
||||
expect(primitive.widgets.value.value).toBe(1);
|
||||
primitive.widgets.value.value = 64;
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async (r) => {
|
||||
primitive = graph.find(primitive);
|
||||
await waitForWidget(primitive);
|
||||
|
||||
// Ensure widget configs are merged
|
||||
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
|
||||
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
|
||||
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
|
||||
|
||||
await checkOutput(graph, {
|
||||
width: 64,
|
||||
height: 64,
|
||||
batch_size: 64,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
@@ -1,452 +0,0 @@
|
||||
// @ts-check
|
||||
/// <reference path="../../web/types/litegraph.d.ts" />
|
||||
|
||||
/**
|
||||
* @typedef { import("../../web/scripts/app")["app"] } app
|
||||
* @typedef { import("../../web/types/litegraph") } LG
|
||||
* @typedef { import("../../web/types/litegraph").IWidget } IWidget
|
||||
* @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
|
||||
* @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
|
||||
* @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
|
||||
* @typedef { InstanceType<LG["LGraphNode"]> & { widgets?: Array<IWidget> } } LGNode
|
||||
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record<string, unknown>]) => EzNode } EzNodeFactory
|
||||
*/
|
||||
|
||||
export class EzConnection {
|
||||
/** @type { app } */
|
||||
app;
|
||||
/** @type { InstanceType<LG["LLink"]> } */
|
||||
link;
|
||||
|
||||
get originNode() {
|
||||
return new EzNode(this.app, this.app.graph.getNodeById(this.link.origin_id));
|
||||
}
|
||||
|
||||
get originOutput() {
|
||||
return this.originNode.outputs[this.link.origin_slot];
|
||||
}
|
||||
|
||||
get targetNode() {
|
||||
return new EzNode(this.app, this.app.graph.getNodeById(this.link.target_id));
|
||||
}
|
||||
|
||||
get targetInput() {
|
||||
return this.targetNode.inputs[this.link.target_slot];
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { app } app
|
||||
* @param { InstanceType<LG["LLink"]> } link
|
||||
*/
|
||||
constructor(app, link) {
|
||||
this.app = app;
|
||||
this.link = link;
|
||||
}
|
||||
|
||||
disconnect() {
|
||||
this.targetInput.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
export class EzSlot {
|
||||
/** @type { EzNode } */
|
||||
node;
|
||||
/** @type { number } */
|
||||
index;
|
||||
|
||||
/**
|
||||
* @param { EzNode } node
|
||||
* @param { number } index
|
||||
*/
|
||||
constructor(node, index) {
|
||||
this.node = node;
|
||||
this.index = index;
|
||||
}
|
||||
}
|
||||
|
||||
export class EzInput extends EzSlot {
|
||||
/** @type { INodeInputSlot } */
|
||||
input;
|
||||
|
||||
/**
|
||||
* @param { EzNode } node
|
||||
* @param { number } index
|
||||
* @param { INodeInputSlot } input
|
||||
*/
|
||||
constructor(node, index, input) {
|
||||
super(node, index);
|
||||
this.input = input;
|
||||
}
|
||||
|
||||
get connection() {
|
||||
const link = this.node.node.inputs?.[this.index]?.link;
|
||||
if (link == null) {
|
||||
return null;
|
||||
}
|
||||
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
|
||||
}
|
||||
|
||||
disconnect() {
|
||||
this.node.node.disconnectInput(this.index);
|
||||
}
|
||||
}
|
||||
|
||||
export class EzOutput extends EzSlot {
|
||||
/** @type { INodeOutputSlot } */
|
||||
output;
|
||||
|
||||
/**
|
||||
* @param { EzNode } node
|
||||
* @param { number } index
|
||||
* @param { INodeOutputSlot } output
|
||||
*/
|
||||
constructor(node, index, output) {
|
||||
super(node, index);
|
||||
this.output = output;
|
||||
}
|
||||
|
||||
get connections() {
|
||||
return (this.node.node.outputs?.[this.index]?.links ?? []).map(
|
||||
(l) => new EzConnection(this.node.app, this.node.app.graph.links[l])
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { EzInput } input
|
||||
*/
|
||||
connectTo(input) {
|
||||
if (!input) throw new Error("Invalid input");
|
||||
|
||||
/**
|
||||
* @type { LG["LLink"] | null }
|
||||
*/
|
||||
const link = this.node.node.connect(this.index, input.node.node, input.index);
|
||||
if (!link) {
|
||||
const inp = input.input;
|
||||
const inName = inp.name || inp.label || inp.type;
|
||||
throw new Error(
|
||||
`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
|
||||
this.output.name ?? this.output.type
|
||||
}#${this.index}] failed.`
|
||||
);
|
||||
}
|
||||
return link;
|
||||
}
|
||||
}
|
||||
|
||||
export class EzNodeMenuItem {
|
||||
/** @type { EzNode } */
|
||||
node;
|
||||
/** @type { number } */
|
||||
index;
|
||||
/** @type { ContextMenuItem } */
|
||||
item;
|
||||
|
||||
/**
|
||||
* @param { EzNode } node
|
||||
* @param { number } index
|
||||
* @param { ContextMenuItem } item
|
||||
*/
|
||||
constructor(node, index, item) {
|
||||
this.node = node;
|
||||
this.index = index;
|
||||
this.item = item;
|
||||
}
|
||||
|
||||
call(selectNode = true) {
|
||||
if (!this.item?.callback) throw new Error(`Menu Item ${this.item?.content ?? "[null]"} has no callback.`);
|
||||
if (selectNode) {
|
||||
this.node.select();
|
||||
}
|
||||
return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
|
||||
}
|
||||
}
|
||||
|
||||
export class EzWidget {
|
||||
/** @type { EzNode } */
|
||||
node;
|
||||
/** @type { number } */
|
||||
index;
|
||||
/** @type { IWidget } */
|
||||
widget;
|
||||
|
||||
/**
|
||||
* @param { EzNode } node
|
||||
* @param { number } index
|
||||
* @param { IWidget } widget
|
||||
*/
|
||||
constructor(node, index, widget) {
|
||||
this.node = node;
|
||||
this.index = index;
|
||||
this.widget = widget;
|
||||
}
|
||||
|
||||
get value() {
|
||||
return this.widget.value;
|
||||
}
|
||||
|
||||
set value(v) {
|
||||
this.widget.value = v;
|
||||
this.widget.callback?.call?.(this.widget, v)
|
||||
}
|
||||
|
||||
get isConvertedToInput() {
|
||||
// @ts-ignore : this type is valid for converted widgets
|
||||
return this.widget.type === "converted-widget";
|
||||
}
|
||||
|
||||
getConvertedInput() {
|
||||
if (!this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} is not converted to input.`);
|
||||
|
||||
return this.node.inputs.find((inp) => inp.input["widget"]?.name === this.widget.name);
|
||||
}
|
||||
|
||||
convertToWidget() {
|
||||
if (!this.isConvertedToInput)
|
||||
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
|
||||
var menu = this.node.menu["Convert Input to Widget"].item.submenu.options;
|
||||
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to widget`);
|
||||
menu[index].callback.call();
|
||||
}
|
||||
|
||||
convertToInput() {
|
||||
if (this.isConvertedToInput)
|
||||
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
|
||||
var menu = this.node.menu["Convert Widget to Input"].item.submenu.options;
|
||||
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to input`);
|
||||
menu[index].callback.call();
|
||||
}
|
||||
}
|
||||
|
||||
export class EzNode {
|
||||
/** @type { app } */
|
||||
app;
|
||||
/** @type { LGNode } */
|
||||
node;
|
||||
|
||||
/**
|
||||
* @param { app } app
|
||||
* @param { LGNode } node
|
||||
*/
|
||||
constructor(app, node) {
|
||||
this.app = app;
|
||||
this.node = node;
|
||||
}
|
||||
|
||||
get id() {
|
||||
return this.node.id;
|
||||
}
|
||||
|
||||
get inputs() {
|
||||
return this.#makeLookupArray("inputs", "name", EzInput);
|
||||
}
|
||||
|
||||
get outputs() {
|
||||
return this.#makeLookupArray("outputs", "name", EzOutput);
|
||||
}
|
||||
|
||||
get widgets() {
|
||||
return this.#makeLookupArray("widgets", "name", EzWidget);
|
||||
}
|
||||
|
||||
get menu() {
|
||||
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
|
||||
}
|
||||
|
||||
get isRemoved() {
|
||||
return !this.app.graph.getNodeById(this.id);
|
||||
}
|
||||
|
||||
select(addToSelection = false) {
|
||||
this.app.canvas.selectNode(this.node, addToSelection);
|
||||
}
|
||||
|
||||
// /**
|
||||
// * @template { "inputs" | "outputs" } T
|
||||
// * @param { T } type
|
||||
// * @returns { Record<string, type extends "inputs" ? EzInput : EzOutput> & (type extends "inputs" ? EzInput [] : EzOutput[]) }
|
||||
// */
|
||||
// #getSlotItems(type) {
|
||||
// // @ts-ignore : these items are correct
|
||||
// return (this.node[type] ?? []).reduce((p, s, i) => {
|
||||
// if (s.name in p) {
|
||||
// throw new Error(`Unable to store input ${s.name} on array as name conflicts.`);
|
||||
// }
|
||||
// // @ts-ignore
|
||||
// p.push((p[s.name] = new (type === "inputs" ? EzInput : EzOutput)(this, i, s)));
|
||||
// return p;
|
||||
// }, Object.assign([], { $: this }));
|
||||
// }
|
||||
|
||||
/**
|
||||
* @template { { new(node: EzNode, index: number, obj: any): any } } T
|
||||
* @param { "inputs" | "outputs" | "widgets" | (() => Array<unknown>) } nodeProperty
|
||||
* @param { string } nameProperty
|
||||
* @param { T } ctor
|
||||
* @returns { Record<string, InstanceType<T>> & Array<InstanceType<T>> }
|
||||
*/
|
||||
#makeLookupArray(nodeProperty, nameProperty, ctor) {
|
||||
const items = typeof nodeProperty === "function" ? nodeProperty() : this.node[nodeProperty];
|
||||
// @ts-ignore
|
||||
return (items ?? []).reduce((p, s, i) => {
|
||||
if (!s) return p;
|
||||
|
||||
const name = s[nameProperty];
|
||||
const item = new ctor(this, i, s);
|
||||
// @ts-ignore
|
||||
p.push(item);
|
||||
if (name) {
|
||||
// @ts-ignore
|
||||
if (name in p) {
|
||||
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
|
||||
}
|
||||
}
|
||||
// @ts-ignore
|
||||
p[name] = item;
|
||||
return p;
|
||||
}, Object.assign([], { $: this }));
|
||||
}
|
||||
}
|
||||
|
||||
export class EzGraph {
|
||||
/** @type { app } */
|
||||
app;
|
||||
|
||||
/**
|
||||
* @param { app } app
|
||||
*/
|
||||
constructor(app) {
|
||||
this.app = app;
|
||||
}
|
||||
|
||||
get nodes() {
|
||||
return this.app.graph._nodes.map((n) => new EzNode(this.app, n));
|
||||
}
|
||||
|
||||
clear() {
|
||||
this.app.graph.clear();
|
||||
}
|
||||
|
||||
arrange() {
|
||||
this.app.graph.arrange();
|
||||
}
|
||||
|
||||
stringify() {
|
||||
return JSON.stringify(this.app.graph.serialize(), undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { number | LGNode | EzNode } obj
|
||||
* @returns { EzNode }
|
||||
*/
|
||||
find(obj) {
|
||||
let match;
|
||||
let id;
|
||||
if (typeof obj === "number") {
|
||||
id = obj;
|
||||
} else {
|
||||
id = obj.id;
|
||||
}
|
||||
|
||||
match = this.app.graph.getNodeById(id);
|
||||
|
||||
if (!match) {
|
||||
throw new Error(`Unable to find node with ID ${id}.`);
|
||||
}
|
||||
|
||||
return new EzNode(this.app, match);
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns { Promise<void> }
|
||||
*/
|
||||
reload() {
|
||||
const graph = JSON.parse(JSON.stringify(this.app.graph.serialize()));
|
||||
return new Promise((r) => {
|
||||
this.app.graph.clear();
|
||||
setTimeout(async () => {
|
||||
await this.app.loadGraphData(graph);
|
||||
r();
|
||||
}, 10);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns { Promise<{
|
||||
* workflow: {},
|
||||
* output: Record<string, {
|
||||
* class_name: string,
|
||||
* inputs: Record<string, [string, number] | unknown>
|
||||
* }>}> }
|
||||
*/
|
||||
toPrompt() {
|
||||
// @ts-ignore
|
||||
return this.app.graphToPrompt();
|
||||
}
|
||||
}
|
||||
|
||||
export const Ez = {
|
||||
/**
|
||||
* Quickly build and interact with a ComfyUI graph
|
||||
* @example
|
||||
* const { ez, graph } = Ez.graph(app);
|
||||
* graph.clear();
|
||||
* const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
|
||||
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
|
||||
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
|
||||
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
|
||||
* const [image] = ez.VAEDecode(latent, vae).outputs;
|
||||
* const saveNode = ez.SaveImage(image);
|
||||
* console.log(saveNode);
|
||||
* graph.arrange();
|
||||
* @param { app } app
|
||||
* @param { LG["LiteGraph"] } LiteGraph
|
||||
* @param { LG["LGraphCanvas"] } LGraphCanvas
|
||||
* @param { boolean } clearGraph
|
||||
* @returns { { graph: EzGraph, ez: Record<string, EzNodeFactory> } }
|
||||
*/
|
||||
graph(app, LiteGraph = window["LiteGraph"], LGraphCanvas = window["LGraphCanvas"], clearGraph = true) {
|
||||
// Always set the active canvas so things work
|
||||
LGraphCanvas.active_canvas = app.canvas;
|
||||
|
||||
if (clearGraph) {
|
||||
app.graph.clear();
|
||||
}
|
||||
|
||||
// @ts-ignore : this proxy handles utility methods & node creation
|
||||
const factory = new Proxy(
|
||||
{},
|
||||
{
|
||||
get(_, p) {
|
||||
if (typeof p !== "string") throw new Error("Invalid node");
|
||||
const node = LiteGraph.createNode(p);
|
||||
if (!node) throw new Error(`Unknown node "${p}"`);
|
||||
app.graph.add(node);
|
||||
|
||||
/**
|
||||
* @param {Parameters<EzNodeFactory>} args
|
||||
*/
|
||||
return function (...args) {
|
||||
const ezNode = new EzNode(app, node);
|
||||
const inputs = ezNode.inputs;
|
||||
|
||||
let slot = 0;
|
||||
for (const arg of args) {
|
||||
if (arg instanceof EzOutput) {
|
||||
arg.connectTo(inputs[slot++]);
|
||||
} else {
|
||||
for (const k in arg) {
|
||||
ezNode.widgets[k].value = arg[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ezNode;
|
||||
};
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
return { graph: new EzGraph(app), ez: factory };
|
||||
},
|
||||
};
|
@@ -1,129 +0,0 @@
|
||||
const { mockApi } = require("./setup");
|
||||
const { Ez } = require("./ezgraph");
|
||||
const lg = require("./litegraph");
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
|
||||
const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html"))
|
||||
|
||||
/**
|
||||
*
|
||||
* @param { Parameters<typeof mockApi>[0] & {
|
||||
* resetEnv?: boolean,
|
||||
* preSetup?(app): Promise<void>,
|
||||
* localStorage?: Record<string, string>
|
||||
* } } config
|
||||
* @returns
|
||||
*/
|
||||
export async function start(config = {}) {
|
||||
if(config.resetEnv) {
|
||||
jest.resetModules();
|
||||
jest.resetAllMocks();
|
||||
lg.setup(global);
|
||||
localStorage.clear();
|
||||
sessionStorage.clear();
|
||||
}
|
||||
|
||||
Object.assign(localStorage, config.localStorage ?? {});
|
||||
document.body.innerHTML = html;
|
||||
|
||||
mockApi(config);
|
||||
const { app } = require("../../web/scripts/app");
|
||||
config.preSetup?.(app);
|
||||
await app.setup();
|
||||
|
||||
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
||||
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
|
||||
*/
|
||||
export async function checkBeforeAndAfterReload(graph, cb) {
|
||||
await cb(false);
|
||||
await graph.reload();
|
||||
await cb(true);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { string } name
|
||||
* @param { Record<string, string | [string | string[], any]> } input
|
||||
* @param { (string | string[])[] | Record<string, string | string[]> } output
|
||||
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
|
||||
*/
|
||||
export function makeNodeDef(name, input, output = {}) {
|
||||
const nodeDef = {
|
||||
name,
|
||||
category: "test",
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
input: {
|
||||
required: {},
|
||||
},
|
||||
};
|
||||
for (const k in input) {
|
||||
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
|
||||
}
|
||||
if (output instanceof Array) {
|
||||
output = output.reduce((p, c) => {
|
||||
p[c] = c;
|
||||
return p;
|
||||
}, {});
|
||||
}
|
||||
for (const k in output) {
|
||||
nodeDef.output.push(output[k]);
|
||||
nodeDef.output_name.push(k);
|
||||
nodeDef.output_is_list.push(false);
|
||||
}
|
||||
|
||||
return { [name]: nodeDef };
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
* @template { any } T
|
||||
* @param { T } x
|
||||
* @returns { x is Exclude<T, null | undefined> }
|
||||
*/
|
||||
export function assertNotNullOrUndefined(x) {
|
||||
expect(x).not.toEqual(null);
|
||||
expect(x).not.toEqual(undefined);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param { ReturnType<Ez["graph"]>["ez"] } ez
|
||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
||||
*/
|
||||
export function createDefaultWorkflow(ez, graph) {
|
||||
graph.clear();
|
||||
const ckpt = ez.CheckpointLoaderSimple();
|
||||
|
||||
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
|
||||
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
|
||||
|
||||
const empty = ez.EmptyLatentImage();
|
||||
const sampler = ez.KSampler(
|
||||
ckpt.outputs.MODEL,
|
||||
pos.outputs.CONDITIONING,
|
||||
neg.outputs.CONDITIONING,
|
||||
empty.outputs.LATENT
|
||||
);
|
||||
|
||||
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
|
||||
const save = ez.SaveImage(decode.outputs.IMAGE);
|
||||
graph.arrange();
|
||||
|
||||
return { ckpt, pos, neg, empty, sampler, decode, save };
|
||||
}
|
||||
|
||||
export async function getNodeDefs() {
|
||||
const { api } = require("../../web/scripts/api");
|
||||
return api.getNodeDefs();
|
||||
}
|
||||
|
||||
export async function getNodeDef(nodeId) {
|
||||
return (await getNodeDefs())[nodeId];
|
||||
}
|
@@ -1,36 +0,0 @@
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
const { nop } = require("../utils/nopProxy");
|
||||
|
||||
function forEachKey(cb) {
|
||||
for (const k of [
|
||||
"LiteGraph",
|
||||
"LGraph",
|
||||
"LLink",
|
||||
"LGraphNode",
|
||||
"LGraphGroup",
|
||||
"DragAndScale",
|
||||
"LGraphCanvas",
|
||||
"ContextMenu",
|
||||
]) {
|
||||
cb(k);
|
||||
}
|
||||
}
|
||||
|
||||
export function setup(ctx) {
|
||||
const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
|
||||
const globalTemp = {};
|
||||
(function (console) {
|
||||
eval(lg);
|
||||
}).call(globalTemp, nop);
|
||||
|
||||
forEachKey((k) => (ctx[k] = globalTemp[k]));
|
||||
require(path.resolve("../web/lib/litegraph.extensions.js"));
|
||||
}
|
||||
|
||||
export function teardown(ctx) {
|
||||
forEachKey((k) => delete ctx[k]);
|
||||
|
||||
// Clear document after each run
|
||||
document.getElementsByTagName("html")[0].innerHTML = "";
|
||||
}
|
@@ -1,6 +0,0 @@
|
||||
export const nop = new Proxy(function () {}, {
|
||||
get: () => nop,
|
||||
set: () => true,
|
||||
apply: () => nop,
|
||||
construct: () => nop,
|
||||
});
|
@@ -1,82 +0,0 @@
|
||||
require("../../web/scripts/api");
|
||||
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
function* walkSync(dir) {
|
||||
const files = fs.readdirSync(dir, { withFileTypes: true });
|
||||
for (const file of files) {
|
||||
if (file.isDirectory()) {
|
||||
yield* walkSync(path.join(dir, file.name));
|
||||
} else {
|
||||
yield path.join(dir, file.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
|
||||
*/
|
||||
|
||||
/**
|
||||
* @param {{
|
||||
* mockExtensions?: string[],
|
||||
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
|
||||
* settings?: Record<string, string>
|
||||
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
|
||||
* userData?: Record<string, any>
|
||||
* }} config
|
||||
*/
|
||||
export function mockApi(config = {}) {
|
||||
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
|
||||
userConfig,
|
||||
settings: {},
|
||||
userData: {},
|
||||
...config,
|
||||
};
|
||||
if (!mockExtensions) {
|
||||
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
|
||||
.filter((x) => x.endsWith(".js"))
|
||||
.map((x) => path.relative(path.resolve("../web"), x));
|
||||
}
|
||||
if (!mockNodeDefs) {
|
||||
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
|
||||
}
|
||||
|
||||
const events = new EventTarget();
|
||||
const mockApi = {
|
||||
addEventListener: events.addEventListener.bind(events),
|
||||
removeEventListener: events.removeEventListener.bind(events),
|
||||
dispatchEvent: events.dispatchEvent.bind(events),
|
||||
getSystemStats: jest.fn(),
|
||||
getExtensions: jest.fn(() => mockExtensions),
|
||||
getNodeDefs: jest.fn(() => mockNodeDefs),
|
||||
init: jest.fn(),
|
||||
apiURL: jest.fn((x) => "../../web/" + x),
|
||||
createUser: jest.fn((username) => {
|
||||
if(username in userConfig.users) {
|
||||
return { status: 400, json: () => "Duplicate" }
|
||||
}
|
||||
userConfig.users[username + "!"] = username;
|
||||
return { status: 200, json: () => username + "!" }
|
||||
}),
|
||||
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
|
||||
getSettings: jest.fn(() => settings),
|
||||
storeSettings: jest.fn((v) => Object.assign(settings, v)),
|
||||
getUserData: jest.fn((f) => {
|
||||
if (f in userData) {
|
||||
return { status: 200, json: () => userData[f] };
|
||||
} else {
|
||||
return { status: 404 };
|
||||
}
|
||||
}),
|
||||
storeUserData: jest.fn((file, data) => {
|
||||
userData[file] = data;
|
||||
}),
|
||||
listUserData: jest.fn(() => [])
|
||||
};
|
||||
jest.mock("../../web/scripts/api", () => ({
|
||||
get api() {
|
||||
return mockApi;
|
||||
},
|
||||
}));
|
||||
}
|
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.frontend_management import (
|
||||
FrontendManager,
|
||||
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
|
||||
with pytest.raises(HTTPError):
|
||||
FrontendManager.init_frontend_unsafe(version_string)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_os_functions():
|
||||
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
|
||||
patch('app.frontend_management.os.listdir') as mock_listdir, \
|
||||
patch('app.frontend_management.os.rmdir') as mock_rmdir:
|
||||
mock_listdir.return_value = [] # Simulate empty directory
|
||||
yield mock_makedirs, mock_listdir, mock_rmdir
|
||||
|
||||
@pytest.fixture
|
||||
def mock_download():
|
||||
with patch('app.frontend_management.download_release_asset_zip') as mock:
|
||||
mock.side_effect = Exception("Download failed") # Simulate download failure
|
||||
yield mock
|
||||
|
||||
def test_finally_block(mock_os_functions, mock_download, mock_provider):
|
||||
# Arrange
|
||||
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
|
||||
version_string = 'test-owner/test-repo@1.0.0'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception):
|
||||
FrontendManager.init_frontend_unsafe(version_string, mock_provider)
|
||||
|
||||
# Assert
|
||||
mock_makedirs.assert_called_once()
|
||||
mock_download.assert_called_once()
|
||||
mock_listdir.assert_called_once()
|
||||
mock_rmdir.assert_called_once()
|
||||
|
||||
|
||||
def test_parse_version_string():
|
||||
version_string = "owner/repo@1.0.0"
|
||||
|
0
tests-unit/prompt_server_test/__init__.py
Normal file
0
tests-unit/prompt_server_test/__init__.py
Normal file
321
tests-unit/prompt_server_test/download_models_test.py
Normal file
321
tests-unit/prompt_server_test/download_models_test.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import pytest
|
||||
import aiohttp
|
||||
from aiohttp import ClientResponse
|
||||
import itertools
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||
|
||||
class AsyncIteratorMock:
|
||||
"""
|
||||
A mock class that simulates an asynchronous iterator.
|
||||
This is used to mimic the behavior of aiohttp's content iterator.
|
||||
"""
|
||||
def __init__(self, seq):
|
||||
# Convert the input sequence into an iterator
|
||||
self.iter = iter(seq)
|
||||
|
||||
def __aiter__(self):
|
||||
# This method is called when 'async for' is used
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
# This method is called for each iteration in an 'async for' loop
|
||||
try:
|
||||
return next(self.iter)
|
||||
except StopIteration:
|
||||
# This is the asynchronous equivalent of StopIteration
|
||||
raise StopAsyncIteration
|
||||
|
||||
class ContentMock:
|
||||
"""
|
||||
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
||||
This class provides the iter_chunked method which returns an async iterator of chunks.
|
||||
"""
|
||||
def __init__(self, chunks):
|
||||
# Store the chunks that will be returned by the iterator
|
||||
self.chunks = chunks
|
||||
|
||||
def iter_chunked(self, chunk_size):
|
||||
# This method mimics aiohttp's content.iter_chunked()
|
||||
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
||||
return AsyncIteratorMock(self.chunks)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_success():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 200
|
||||
mock_response.headers = {'Content-Length': '1000'}
|
||||
# Create a mock for content that returns an async iterator directly
|
||||
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
|
||||
mock_response.content = ContentMock(chunks)
|
||||
|
||||
mock_make_request = AsyncMock(return_value=mock_response)
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
# Mock file operations
|
||||
mock_open = MagicMock()
|
||||
mock_file = MagicMock()
|
||||
mock_open.return_value.__enter__.return_value = mock_file
|
||||
time_values = itertools.count(0, 0.1)
|
||||
|
||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('builtins.open', mock_open), \
|
||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
'model.sft',
|
||||
'http://example.com/model.sft',
|
||||
'checkpoints',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, DownloadModelStatus)
|
||||
assert result.message == 'Successfully downloaded model.sft'
|
||||
assert result.status == 'completed'
|
||||
assert result.already_existed is False
|
||||
|
||||
# Check progress callback calls
|
||||
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
||||
|
||||
# Check initial call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.sft',
|
||||
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
|
||||
)
|
||||
|
||||
# Check final call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.sft',
|
||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
|
||||
)
|
||||
|
||||
# Verify file writing
|
||||
mock_file.write.assert_any_call(b'a' * 500)
|
||||
mock_file.write.assert_any_call(b'b' * 300)
|
||||
mock_file.write.assert_any_call(b'c' * 200)
|
||||
|
||||
# Verify request was made
|
||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_url_request_failure():
|
||||
# Mock dependencies
|
||||
mock_response = AsyncMock(spec=ClientResponse)
|
||||
mock_response.status = 404 # Simulate a "Not Found" error
|
||||
mock_get = AsyncMock(return_value=mock_response)
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
# Mock the create_model_path function
|
||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
||||
# Call the function
|
||||
result = await download_model(
|
||||
mock_get,
|
||||
'model.safetensors',
|
||||
'http://example.com/model.safetensors',
|
||||
'mock_directory',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert isinstance(result, DownloadModelStatus)
|
||||
assert result.status == 'error'
|
||||
assert result.message == 'Failed to download model.safetensors. Status code: 404'
|
||||
assert result.already_existed is False
|
||||
|
||||
# Check that progress_callback was called with the correct arguments
|
||||
mock_progress_callback.assert_any_call(
|
||||
'mock_directory/model.safetensors',
|
||||
DownloadModelStatus(
|
||||
status=DownloadStatusType.PENDING,
|
||||
progress_percentage=0,
|
||||
message='Starting download of model.safetensors',
|
||||
already_existed=False
|
||||
)
|
||||
)
|
||||
mock_progress_callback.assert_called_with(
|
||||
'mock_directory/model.safetensors',
|
||||
DownloadModelStatus(
|
||||
status=DownloadStatusType.ERROR,
|
||||
progress_percentage=0,
|
||||
message='Failed to download model.safetensors. Status code: 404',
|
||||
already_existed=False
|
||||
)
|
||||
)
|
||||
|
||||
# Verify that the get method was called with the correct URL
|
||||
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_invalid_model_subdirectory():
|
||||
|
||||
mock_make_request = AsyncMock()
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
'model.sft',
|
||||
'http://example.com/model.sft',
|
||||
'../bad_path',
|
||||
mock_progress_callback
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, DownloadModelStatus)
|
||||
assert result.message == 'Invalid model subdirectory'
|
||||
assert result.status == 'error'
|
||||
assert result.already_existed is False
|
||||
|
||||
|
||||
# For create_model_path function
|
||||
def test_create_model_path(tmp_path, monkeypatch):
|
||||
mock_models_dir = tmp_path / "models"
|
||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||
|
||||
model_name = "test_model.sft"
|
||||
model_directory = "test_dir"
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||
|
||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||
assert relative_path == f"{model_directory}/{model_name}"
|
||||
assert os.path.exists(os.path.dirname(file_path))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||
file_path = tmp_path / "existing_model.sft"
|
||||
file_path.touch() # Create an empty file
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
||||
|
||||
assert result is not None
|
||||
assert result.status == "completed"
|
||||
assert result.message == "existing_model.sft already exists"
|
||||
assert result.already_existed is True
|
||||
|
||||
mock_callback.assert_called_once_with(
|
||||
"test/existing_model.sft",
|
||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||
file_path = tmp_path / "non_existing_model.sft"
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
||||
|
||||
assert result is None
|
||||
mock_callback.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_no_content_length():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.headers = {} # No Content-Length header
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
mock_open = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch('builtins.open', mock_open):
|
||||
result = await track_download_progress(
|
||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||
mock_callback, 'models/model.sft', interval=0.1
|
||||
)
|
||||
|
||||
assert result.status == "completed"
|
||||
# Check that progress was reported even without knowing the total size
|
||||
mock_callback.assert_any_call(
|
||||
'models/model.sft',
|
||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_interval():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.headers = {'Content-Length': '1000'}
|
||||
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
mock_open = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create a mock time function that returns incremental float values
|
||||
mock_time = MagicMock()
|
||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||
|
||||
with patch('builtins.open', mock_open), \
|
||||
patch('time.time', mock_time):
|
||||
await track_download_progress(
|
||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||
mock_callback, 'models/model.sft', interval=1.0
|
||||
)
|
||||
|
||||
# Print out the actual call count and the arguments of each call for debugging
|
||||
print(f"mock_callback was called {mock_callback.call_count} times")
|
||||
for i, call in enumerate(mock_callback.call_args_list):
|
||||
args, kwargs = call
|
||||
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
|
||||
|
||||
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
|
||||
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
|
||||
|
||||
# Verify the first and last calls
|
||||
first_call = mock_callback.call_args_list[0]
|
||||
assert first_call[0][1].status == "in_progress"
|
||||
# Allow for some initial progress, but it should be less than 50%
|
||||
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
|
||||
|
||||
last_call = mock_callback.call_args_list[-1]
|
||||
assert last_call[0][1].status == "completed"
|
||||
assert last_call[0][1].progress_percentage == 100
|
||||
|
||||
def test_valid_subdirectory():
|
||||
assert validate_model_subdirectory("valid-model123") is True
|
||||
|
||||
def test_subdirectory_too_long():
|
||||
assert validate_model_subdirectory("a" * 51) is False
|
||||
|
||||
def test_subdirectory_with_double_dots():
|
||||
assert validate_model_subdirectory("model/../unsafe") is False
|
||||
|
||||
def test_subdirectory_with_slash():
|
||||
assert validate_model_subdirectory("model/unsafe") is False
|
||||
|
||||
def test_subdirectory_with_special_characters():
|
||||
assert validate_model_subdirectory("model@unsafe") is False
|
||||
|
||||
def test_subdirectory_with_underscore_and_dash():
|
||||
assert validate_model_subdirectory("valid_model-name") is True
|
||||
|
||||
def test_empty_subdirectory():
|
||||
assert validate_model_subdirectory("") is False
|
||||
|
||||
@pytest.mark.parametrize("filename, expected", [
|
||||
("valid_model.safetensors", True),
|
||||
("valid_model.sft", True),
|
||||
("valid model.safetensors", True), # Test with space
|
||||
("UPPERCASE_MODEL.SAFETENSORS", True),
|
||||
("model_with.multiple.dots.pt", False),
|
||||
("", False), # Empty string
|
||||
("../../../etc/passwd", False), # Path traversal attempt
|
||||
("/etc/passwd", False), # Absolute path
|
||||
("\\windows\\system32\\config\\sam", False), # Windows path
|
||||
(".hidden_file.pt", False), # Hidden file
|
||||
("invalid<char>.ckpt", False), # Invalid character
|
||||
("invalid?.ckpt", False), # Another invalid character
|
||||
("very" * 100 + ".safetensors", False), # Too long filename
|
||||
("\nmodel_with_newline.pt", False), # Newline character
|
||||
("model_with_emoji😊.pt", False), # Emoji in filename
|
||||
])
|
||||
def test_validate_filename(filename, expected):
|
||||
assert validate_filename(filename) == expected
|
@@ -1 +1,3 @@
|
||||
pytest>=7.8.0
|
||||
pytest-aiohttp
|
||||
pytest-asyncio
|
||||
|
115
tests-unit/server/routes/internal_routes_test.py
Normal file
115
tests-unit/server/routes/internal_routes_test.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from unittest.mock import MagicMock, patch
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from api_server.services.file_service import FileService
|
||||
from folder_paths import models_dir, user_directory, output_directory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def internal_routes():
|
||||
return InternalRoutes()
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client_factory(aiohttp_client, internal_routes):
|
||||
async def _get_client():
|
||||
app = internal_routes.get_app()
|
||||
return await aiohttp_client(app)
|
||||
return _get_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
|
||||
mock_file_list = [
|
||||
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||
]
|
||||
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files?directory=models')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert 'files' in data
|
||||
assert len(data['files']) == 2
|
||||
assert data['files'] == mock_file_list
|
||||
|
||||
# Check other valid directories
|
||||
resp = await client.get('/files?directory=user')
|
||||
assert resp.status == 200
|
||||
resp = await client.get('/files?directory=output')
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
|
||||
internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files?directory=invalid')
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert 'error' in data
|
||||
assert data['error'] == "Invalid directory key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_exception(aiohttp_client_factory, internal_routes):
|
||||
internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files?directory=models')
|
||||
assert resp.status == 500
|
||||
data = await resp.json()
|
||||
assert 'error' in data
|
||||
assert data['error'] == "Unexpected error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
|
||||
mock_file_list = []
|
||||
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert 'files' in data
|
||||
assert len(data['files']) == 0
|
||||
|
||||
def test_setup_routes(internal_routes):
|
||||
internal_routes.setup_routes()
|
||||
routes = internal_routes.routes
|
||||
assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)
|
||||
|
||||
def test_get_app(internal_routes):
|
||||
app = internal_routes.get_app()
|
||||
assert isinstance(app, web.Application)
|
||||
assert internal_routes._app is not None
|
||||
|
||||
def test_get_app_reuse(internal_routes):
|
||||
app1 = internal_routes.get_app()
|
||||
app2 = internal_routes.get_app()
|
||||
assert app1 is app2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
|
||||
client = await aiohttp_client_factory()
|
||||
try:
|
||||
resp = await client.get('/files')
|
||||
print(f"Response received: status {resp.status}")
|
||||
except Exception as e:
|
||||
print(f"Exception occurred during GET request: {e}")
|
||||
raise
|
||||
|
||||
assert resp.status != 404, "Route /files does not exist"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_service_initialization():
|
||||
with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
|
||||
# Create a mock instance
|
||||
mock_file_service_instance = MagicMock(spec=FileService)
|
||||
MockFileService.return_value = mock_file_service_instance
|
||||
internal_routes = InternalRoutes()
|
||||
|
||||
# Check if FileService was initialized with the correct parameters
|
||||
MockFileService.assert_called_once_with({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
|
||||
# Verify that the file_service attribute of InternalRoutes is set
|
||||
assert internal_routes.file_service == mock_file_service_instance
|
54
tests-unit/server/services/file_service_test.py
Normal file
54
tests-unit/server/services/file_service_test.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from api_server.services.file_service import FileService
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_system_ops():
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def file_service(mock_file_system_ops):
|
||||
allowed_directories = {
|
||||
"models": "/path/to/models",
|
||||
"user": "/path/to/user",
|
||||
"output": "/path/to/output"
|
||||
}
|
||||
return FileService(allowed_directories, file_system_ops=mock_file_system_ops)
|
||||
|
||||
def test_list_files_valid_directory(file_service, mock_file_system_ops):
|
||||
mock_file_system_ops.walk_directory.return_value = [
|
||||
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||
]
|
||||
|
||||
result = file_service.list_files("models")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "file1.txt"
|
||||
assert result[1]["name"] == "dir1"
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||
|
||||
def test_list_files_invalid_directory(file_service):
|
||||
# Does not support walking directories outside of the allowed directories
|
||||
with pytest.raises(ValueError, match="Invalid directory key"):
|
||||
file_service.list_files("invalid_key")
|
||||
|
||||
def test_list_files_empty_directory(file_service, mock_file_system_ops):
|
||||
mock_file_system_ops.walk_directory.return_value = []
|
||||
|
||||
result = file_service.list_files("models")
|
||||
|
||||
assert len(result) == 0
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||
|
||||
@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
|
||||
def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
|
||||
mock_file_system_ops.walk_directory.return_value = [
|
||||
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
|
||||
]
|
||||
|
||||
result = file_service.list_files(directory_key)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == f"file_{directory_key}.txt"
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
42
tests-unit/server/utils/file_operations_test.py
Normal file
42
tests-unit/server/utils/file_operations_test.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem, is_file_info
|
||||
|
||||
@pytest.fixture
|
||||
def temp_directory(tmp_path):
|
||||
# Create a temporary directory structure
|
||||
dir1 = tmp_path / "dir1"
|
||||
dir2 = tmp_path / "dir2"
|
||||
dir1.mkdir()
|
||||
dir2.mkdir()
|
||||
(dir1 / "file1.txt").write_text("content1")
|
||||
(dir2 / "file2.txt").write_text("content2")
|
||||
(tmp_path / "file3.txt").write_text("content3")
|
||||
return tmp_path
|
||||
|
||||
def test_walk_directory(temp_directory):
|
||||
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||
|
||||
assert len(result) == 5 # 2 directories and 3 files
|
||||
|
||||
files = [item for item in result if item['type'] == 'file']
|
||||
dirs = [item for item in result if item['type'] == 'directory']
|
||||
|
||||
assert len(files) == 3
|
||||
assert len(dirs) == 2
|
||||
|
||||
file_names = {file['name'] for file in files}
|
||||
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
|
||||
|
||||
dir_names = {dir['name'] for dir in dirs}
|
||||
assert dir_names == {'dir1', 'dir2'}
|
||||
|
||||
def test_walk_directory_empty(tmp_path):
|
||||
result = FileSystemOperations.walk_directory(str(tmp_path))
|
||||
assert len(result) == 0
|
||||
|
||||
def test_walk_directory_file_size(temp_directory):
|
||||
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||
files = [item for item in result if is_file_info(item)]
|
||||
for file in files:
|
||||
assert file['size'] > 0 # Assuming all files have some content
|
4
tests/inference/extra_model_paths.yaml
Normal file
4
tests/inference/extra_model_paths.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
# Config for testing nodes
|
||||
testing:
|
||||
custom_nodes: tests/inference/testing_nodes
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user