mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Compare commits


127 Commits

Author SHA1 Message Date
Alex "mcmonkey" Goodwin
f1c2301697 fix typo in stale-issues (#4735) 2024-09-01 17:44:49 -04:00
comfyanonymous
8d31a6632f Speed up inference on nvidia 10 series on Linux. 2024-09-01 17:29:31 -04:00
comfyanonymous
b643eae08b Make minimum_inference_memory() depend on --reserve-vram 2024-09-01 01:18:34 -04:00
comfyanonymous
baa6b4dc36 Update manual install instructions. 2024-08-31 04:37:23 -04:00
Alex "mcmonkey" Goodwin
d4aeefc297 add github action to automatically handle stale user support issues (#4683)
* add github action to automatically handle stale user support issues

* improve stale message

* remove token part
2024-08-31 01:57:18 -04:00
comfyanonymous
587e7ca654 Remove github buttons. 2024-08-31 01:53:10 -04:00
Chenlei Hu
c90459eba0 Update ComfyUI_frontend to 1.2.40 (#4691)
* Update ComfyUI_frontend to 1.2.40

* Add files
2024-08-30 19:32:10 -04:00
Vedat Baday
04278afb10 feat: return import_failed from init_extra_nodes function (#4694) 2024-08-30 19:26:47 -04:00
comfyanonymous
935ae153e1 Cleanup. 2024-08-30 12:53:59 -04:00
Chenlei Hu
e91662e784 Get logs endpoint & system_stats additions (#4690)
* Add route for getting output logs

* Include ComfyUI version

* Move to own function

* Changed to memory logger

* Unify logger setup logic

* Fix get version git fallback

---------

Co-authored-by: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
2024-08-30 12:46:37 -04:00
comfyanonymous
63fafaef45 Fix potential issue with hydit controlnets. 2024-08-30 04:58:41 -04:00
Alex "mcmonkey" Goodwin
ec28cd9136 swap legacy sdv15 link (#4682)
* swap legacy sdv15 link

* swap v15 ckpt examples to safetensors

* link the fp16 copy of the model by default
2024-08-29 19:48:48 -04:00
comfyanonymous
6eb5d64522 Fix glora lowvram issue. 2024-08-29 19:07:23 -04:00
comfyanonymous
10a79e9898 Implement model part of flux union controlnet. 2024-08-29 18:41:22 -04:00
comfyanonymous
ea3f39bd69 InstantX depth flux controlnet. 2024-08-29 02:14:19 -04:00
comfyanonymous
b33cd61070 InstantX canny controlnet. 2024-08-28 19:02:50 -04:00
Dr.Lt.Data
34eda0f853 fix: remove redundant useless loop (#4656)
fix: potential error of undefined variable

https://github.com/comfyanonymous/ComfyUI/discussions/4650
2024-08-28 17:46:30 -04:00
comfyanonymous
d31e226650 Unify RMSNorm code. 2024-08-28 16:56:38 -04:00
comfyanonymous
b79fd7d92c ComfyUI supports more than just stable diffusion. 2024-08-28 16:12:24 -04:00
comfyanonymous
38c22e631a Fix case where model was not properly unloaded in merging workflows. 2024-08-27 19:03:51 -04:00
Chenlei Hu
6bbdcd28ae Support weight padding on diff weight patch (#4576) 2024-08-27 13:55:37 -04:00
comfyanonymous
ab130001a8 Do RMSNorm in native type. 2024-08-27 02:41:56 -04:00
Chenlei Hu
ca4b8f30e0 Cleanup empty dir if frontend zip download failed (#4574) 2024-08-27 02:07:25 -04:00
Robin Huang
70b84058c1 Add relative file path to the progress report. (#4621) 2024-08-27 02:06:12 -04:00
comfyanonymous
2ca8f6e23d Make the stochastic fp8 rounding reproducible. 2024-08-26 15:12:06 -04:00
comfyanonymous
7985ff88b9 Use less memory in float8 lora patching by doing calculations in fp16. 2024-08-26 14:45:58 -04:00
comfyanonymous
c6812947e9 Fix potential memory leak. 2024-08-26 02:07:32 -04:00
comfyanonymous
9230f65823 Fix some controlnets OOMing when loading. 2024-08-25 05:54:29 -04:00
guill
6ab1e6fd4a [Bug #4529] Fix graph partial validation failure (#4588)
Currently, if a graph partially fails validation (i.e. some outputs are
valid while others have links from missing nodes), the execution loop
could get an exception resulting in server lockup.

This isn't actually possible to reproduce via the default UI, but is a
potential issue for people using the API to construct invalid graphs.
2024-08-24 15:34:58 -04:00
comfyanonymous
07dcbc3a3e Clarify how to use high quality previews. 2024-08-24 02:31:03 -04:00
comfyanonymous
8ae23d8e80 Fix onnx export. 2024-08-23 17:52:47 -04:00
comfyanonymous
7df42b9a23 Fix dora. 2024-08-23 04:58:59 -04:00
comfyanonymous
5d8bbb7281 Cleanup. 2024-08-23 04:06:27 -04:00
comfyanonymous
2c1d2375d6 Fix. 2024-08-23 04:04:55 -04:00
Simon Lui
64ccb3c7e3 Rework IPEX check for future inclusion of XPU into Pytorch upstream and do a bit more optimization of ipex.optimize(). (#4562) 2024-08-23 03:59:57 -04:00
Scorpinaus
9465b23432 Added SD15_Inpaint_Diffusers model support for unet_config_from_diffusers_unet function (#4565) 2024-08-23 03:57:08 -04:00
Chenlei Hu
bb4416dd5b Fix task.status.status_str caused by #2666 (#4551)
* Fix task.status.status_str caused by 2666 regression

* fix

* fix
2024-08-22 17:38:30 -04:00
comfyanonymous
c0b0da264b Missing imports. 2024-08-22 17:20:51 -04:00
comfyanonymous
c26ca27207 Move calculate function to comfy.lora 2024-08-22 17:12:00 -04:00
comfyanonymous
7c6bb84016 Code cleanups. 2024-08-22 17:05:12 -04:00
comfyanonymous
c54d3ed5e6 Fix issue with models staying loaded in memory. 2024-08-22 15:58:20 -04:00
comfyanonymous
c7ee4b37a1 Try to fix some lora issues. 2024-08-22 15:32:18 -04:00
David
7b70b266d8 Generalize MacOS version check for force-upcast-attention (#4548)
This code automatically forces upcasting attention for MacOS versions 14.5 and 14.6. My computer returns the string "14.6.1" for `platform.mac_ver()[0]`, so this generalizes the comparison to catch more versions.

I am running MacOS Sonoma 14.6.1 (latest version) and was seeing black image generation on previously functional workflows after recent software updates. This PR solved the issue for me.

See comfyanonymous/ComfyUI#3521
2024-08-22 13:24:21 -04:00
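A minimal sketch of the generalized check described in that commit (hypothetical helper names, not the actual ComfyUI code): parse `platform.mac_ver()[0]` into a version tuple so that any 14.5+ release matches, instead of comparing exact version strings.

```
import platform

def macos_version():
    # platform.mac_ver()[0] returns strings like "14.6.1" on macOS and "" elsewhere.
    ver_str = platform.mac_ver()[0]
    if not ver_str:
        return None
    return tuple(int(part) for part in ver_str.split("."))

def needs_upcast_attention():
    # Hypothetical helper: any 14.5+ release (14.5, 14.6, 14.6.1, ...) matches,
    # rather than a fixed list of exact version strings.
    ver = macos_version()
    return ver is not None and ver >= (14, 5)
```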
comfyanonymous
8f60d093ba Fix issue. 2024-08-22 10:38:24 -04:00
guill
dafbe321d2 Fix a bug where cached outputs affected IS_CHANGED (#4535)
This change fixes a bug where non-constant values could be passed to the
IS_CHANGED function. This would result in workflows taking an extra
execution before they acted as if they were cached.

The actual change is like 4 characters -- the rest is adding unit tests.
2024-08-21 23:38:46 -04:00
comfyanonymous
5f84ea63e8 Add a shortcut to the nightly package to run with --fast. 2024-08-21 23:36:58 -04:00
comfyanonymous
843a7ff70c fp16 is actually faster than fp32 on a GTX 1080. 2024-08-21 23:23:50 -04:00
comfyanonymous
a60620dcea Fix slow performance on 10 series Nvidia GPUs. 2024-08-21 16:39:02 -04:00
comfyanonymous
015f73dc49 Try a different type of flux fp16 fix. 2024-08-21 16:17:15 -04:00
comfyanonymous
904bf58e7d Make --fast work on pytorch nightly. 2024-08-21 14:01:41 -04:00
Svein Ove Aas
5f50263088 Replace use of .view with .reshape (#4522)
When generating images with fp8_e4_m3 Flux and batch size >1, using --fast, ComfyUI throws a "view size is not compatible with input tensor's size and stride" error pointing at the first of these two calls to view.

As reshape is semantically equivalent to view except for working on a broader set of inputs, there should be no downside to changing this. The only difference is that it clones the underlying data in cases where .view would error out. I have confirmed that the output still looks as expected, but cannot confirm that no mutable use is made of the tensors anywhere.

Note that --fast is only marginally faster than the default.
2024-08-21 11:21:48 -04:00
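For illustration, a small standalone sketch (not ComfyUI code) of why `.reshape` succeeds where `.view` raises the error described above:

```
import torch

x = torch.randn(4, 8, 16)
y = x.transpose(1, 2)      # non-contiguous: strides no longer match a flat layout

# y.view(4, -1) would raise: "view size is not compatible with input tensor's size and stride"
z = y.reshape(4, -1)       # reshape falls back to copying when a view is impossible
print(z.shape)             # torch.Size([4, 128])
```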
Alex "mcmonkey" Goodwin
5e806f555d add a get models list api route (#4519)
* get models list api route

* remove copypasta
2024-08-21 02:04:42 -04:00
Robin Huang
f07e5bb522 Add GET /internal/files. (#4295)
* Create internal route table.

* List files.

* Add GET /internal/files.

Retrieves list of files in models, output, and user directories.

* Refactor file names.

* Use typing_extensions for Python 3.8

* Fix tests.

* Remove print statements.

* Update README.

* Add output and user to valid directory test.

* Add missing type hints.
2024-08-21 01:25:06 -04:00
comfyanonymous
03ec517afb Remove useless line, adjust windows default reserved vram. 2024-08-21 00:47:19 -04:00
Chenlei Hu
f257fc999f Add optional deprecated/experimental flag to node class (#4506)
* Add optional deprecated flag to node class

* nit

* Add experimental flag
2024-08-21 00:01:34 -04:00
Chenlei Hu
bb50e69839 Update frontend to 1.2.30 (#4513) 2024-08-21 00:00:49 -04:00
comfyanonymous
510f3438c1 Speed up fp8 matrix mult by using better code. 2024-08-20 22:53:26 -04:00
comfyanonymous
ea63b1c092 Simpletrainer lycoris format. 2024-08-20 12:05:13 -04:00
comfyanonymous
9953f22fce Add --fast argument to enable experimental optimizations.
Optimizations that might break things/lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
4000/ADA series Nvidia cards or later. If you have one of these cards you
will see a speed boost when using fp8_e4m3fn flux for example.
2024-08-20 11:55:51 -04:00
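As a rough sketch of how such a gate might look (hypothetical helper, not the actual ComfyUI check): the fp8 path is only worth enabling when `--fast` is set and the GPU reports Ada-class compute capability (8.9, i.e. RTX 4000 series) or newer.

```
import torch

def supports_fp8_compute() -> bool:
    # Hypothetical gate: float8_e4m3fn matmul needs an Ada-class NVIDIA GPU
    # (RTX 4000 series, compute capability 8.9) or newer.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

fast_flag = True  # stands in for the --fast command line argument
use_fp8_matmul = fast_flag and supports_fp8_compute()
print("fp8 matmul enabled:", use_fp8_matmul)
```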
comfyanonymous
d1a6bd6845 Support loading long clipl model with the CLIP loader node. 2024-08-20 10:46:36 -04:00
comfyanonymous
83dbac28eb Properly set if clip text pooled projection instead of using hack. 2024-08-20 10:46:36 -04:00
comfyanonymous
538cb068bc Make cast_to a nop if weight is already good. 2024-08-20 10:46:36 -04:00
comfyanonymous
1b3eee672c Fix potential issue with multi devices. 2024-08-20 10:46:36 -04:00
Chenlei Hu
5a69f84c3c Update README.md (Add shield badges) (#4490) 2024-08-19 18:25:20 -04:00
comfyanonymous
9eee470244 New load_text_encoder_state_dicts function.
Now you can load text encoders straight from a list of state dicts.
2024-08-19 17:36:35 -04:00
comfyanonymous
045377ea89 Add a --reserve-vram argument if you don't want comfy to use all of it.
--reserve-vram 1.0 for example will make ComfyUI try to keep 1GB vram free.

This can also be useful if workflows are failing because of OOM errors but
in that case please report it if --reserve-vram improves your situation.
2024-08-19 17:16:18 -04:00
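A hedged sketch of the idea behind the flag (the helper name and exact accounting are assumptions, not ComfyUI's implementation): convert the GB value to bytes and subtract it from the free VRAM the driver reports.

```
import torch

def free_vram_after_reservation(reserve_vram_gb: float) -> int:
    # Hypothetical helper: subtract the reserved amount (GB -> bytes) from the
    # free VRAM reported by the driver, never going below zero.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    reserved_bytes = int(reserve_vram_gb * (1024 ** 3))
    return max(free_bytes - reserved_bytes, 0)

# --reserve-vram 1.0 would leave roughly 1GB of VRAM untouched:
# usable = free_vram_after_reservation(1.0)
```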
comfyanonymous
4d341b78e8 Bug fixes. 2024-08-19 16:28:55 -04:00
comfyanonymous
6138f92084 Use better dtype for the lowvram lora system. 2024-08-19 15:35:25 -04:00
comfyanonymous
be0726c1ed Remove duplication. 2024-08-19 15:26:50 -04:00
comfyanonymous
766ae119a8 CheckpointSave node name. 2024-08-19 15:06:12 -04:00
Yoland Yan
fc90ceb6ba Update issue template config.yml to direct frontend issues to frontend repos (#4486)
* Update config.yml

* Typos
2024-08-19 13:41:30 -04:00
comfyanonymous
4506ddc86a Better subnormal fp8 stochastic rounding. Thanks Ashen. 2024-08-19 13:38:03 -04:00
comfyanonymous
20ace7c853 Code cleanup. 2024-08-19 12:48:59 -04:00
Chenlei Hu
b29b3b86c5 Update README to include frontend section (#4468)
* Update README to include frontend section

* nit
2024-08-19 07:12:32 -04:00
comfyanonymous
22ec02afc0 Handle subnormal numbers in float8 rounding. 2024-08-19 05:51:08 -04:00
comfyanonymous
39f114c44b Less broken non blocking? 2024-08-18 16:53:17 -04:00
comfyanonymous
6730f3e1a3 Disable non blocking.
It fixed some perf issues but caused other issues that need to be debugged.
2024-08-18 14:38:09 -04:00
comfyanonymous
73332160c8 Enable non blocking transfers in lowvram mode. 2024-08-18 10:29:33 -04:00
comfyanonymous
2622c55aff Automatically use RF variant of dpmpp_2s_ancestral if RF model. 2024-08-18 00:47:25 -04:00
Ashen
1beb348ee2 dpmpp_2s_ancestral_RF for rectified flow (Flux, SD3 and Auraflow). 2024-08-18 00:33:30 -04:00
bymyself
9aa39e743c Add new shortcuts to readme (#4442) 2024-08-17 23:52:56 -04:00
comfyanonymous
d31df04c8a Indentation. 2024-08-17 23:00:44 -04:00
Xrvk
e68763f40c Add Flux model support for InstantX style controlnet residuals (#4444)
* Add Flux model support for InstantX style controlnet residuals

* Refactor Flux controlnet residual step to a separate method

* Rollback minor change

* New format for applying controlnet residuals: input->double_blocks, output->single_blocks

* Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals

* Remove unnecessary import and minor style change
2024-08-17 22:58:23 -04:00
comfyanonymous
310ad09258 Add a ModelSave node. 2024-08-17 21:43:07 -04:00
comfyanonymous
4f7a3cb6fb unet -> diffusion_models. 2024-08-17 21:31:04 -04:00
comfyanonymous
bb222ceddb Fix loras having a weak effect when applied on fp8. 2024-08-17 15:20:17 -04:00
comfyanonymous
14af129c55 Improve execution UX.
Some branches with VAELoader -> VAEDecode -> Preview were being executed
last. With this change they will be executed earlier.
2024-08-17 11:37:21 -04:00
comfyanonymous
fca42836f2 Add model_options for text encoder. 2024-08-17 11:17:20 -04:00
comfyanonymous
858d51f91a Fix VAEDecode -> Preview not being executed first. 2024-08-17 04:08:54 -04:00
comfyanonymous
cd5017c1c9 calculate_weight function to use a different dtype. 2024-08-17 01:06:08 -04:00
comfyanonymous
83f343146a Fix potential lowvram issue. 2024-08-16 17:12:42 -04:00
Chenlei Hu
b021cf67c7 Update frontend to 1.2.26 (#4415) 2024-08-16 15:25:02 -04:00
Matthew Turnshek
1770fc77ed Implement support for taef1 latent previews (#4409)
* add taef1 handling to several places

* remove guess_latent_channels and add latent_channels info directly to flux model

* remove TODO

* fix numbers
2024-08-16 12:53:13 -04:00
comfyanonymous
05a9f3faa1 Log a warning when there's an issue with IS_CHANGED. 2024-08-16 08:50:17 -04:00
comfyanonymous
86c5970ac0 Fix custom nodes hooking the map_node_over_list and breaking things. 2024-08-16 08:40:31 -04:00
Chenlei Hu
bfc214f434 Use new TS frontend uncompressed (#4379)
* Swap frontend uncompressed

* Add uncompressed files
2024-08-15 16:50:25 -04:00
comfyanonymous
3f5939add6 Tell github not to count the web directory in language stats. 2024-08-15 13:48:56 -04:00
comfyanonymous
5960f946a9 Move a few files from comfy -> comfy_execution.
Python code in the comfy folder should not import things from outside it.
2024-08-15 11:21:14 -04:00
guill
5cfe38f41c Execution Model Inversion (#2666)
* Execution Model Inversion

This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:

    1. The implementation of lazy evaluation in nodes. For example, if a
    "Mix Images" node has a mix factor of exactly 0.0, the second image
    input doesn't even need to be evaluated (and vice-versa if the mix
    factor is 1.0). A minimal sketch of such a lazy node follows this commit entry.

    2. Dynamic expansion of nodes. This allows for the creation of dynamic
    "node groups". Specifically, custom nodes can return subgraphs that
    replace the original node in the graph. This is an incredibly
    powerful concept. Using this functionality, it was easy to
    implement:
        a. Components (a.k.a. node groups)
        b. Flow control (i.e. while loops) via tail recursion
        c. All-in-one nodes that replicate the WebUI functionality
        d. and more
    All of those were able to be implemented entirely via custom nodes,
    so those features are *not* a part of this PR. (There are some
    front-end changes that should occur before that functionality is
    made widely available, particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.

* Allow `input_info` to be of type `None`

* Handle errors (like OOM) more gracefully

* Add a command-line argument to enable variants

This allows the use of nodes that have sockets of type '*' without
applying a patch to the code.

* Fix an overly aggressive assertion.

This could happen when attempting to evaluate `IS_CHANGED` for a node
during the creation of the cache (in order to create the cache key).

* Fix Pyright warnings

* Add execution model unit tests

* Fix issue with unused literals

Behavior should now match the master branch with regard to undeclared
inputs. Undeclared inputs that are socket connections will be used while
undeclared inputs that are literals will be ignored.

* Make custom VALIDATE_INPUTS skip normal validation

Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.

* Fix example in unit test

This wouldn't have caused any issues in the unit test, but it would have
bugged the UI if someone copy+pasted it into their own node pack.

* Use fstrings instead of '%' formatting syntax

* Use custom exception types.

* Display an error for dependency cycles

Previously, dependency cycles that were created during node expansion
would cause the application to quit (due to an uncaught exception). Now,
we'll throw a proper error to the UI. We also make an attempt to 'blame'
the most relevant node in the UI.

* Add docs on when ExecutionBlocker should be used

* Remove unused functionality

* Rename ExecutionResult.SLEEPING to PENDING

* Remove superfluous function parameter

* Pass None for uneval inputs instead of default

This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values
in evaluation functions.

* Add a test for mixed node expansion

This test ensures that a node that returns a combination of expanded
subgraphs and literal values functions correctly.

* Raise exception for bad get_node calls.

* Minor refactor of IsChangedCache.get

* Refactor `map_node_over_list` function

* Fix ui output for duplicated nodes

* Add documentation on `check_lazy_status`

* Add file for execution model unit tests

* Clean up Javascript code as per review

* Improve documentation

Converted some comments to docstrings as per review

* Add a new unit test for mixed lazy results

This test validates that when an output list is fed to a lazy node, the
node will properly evaluate previous nodes that are needed by any inputs
to the lazy node.

No code in the execution model has been changed. The test already
passes.

* Allow kwargs in VALIDATE_INPUTS functions

When kwargs are used, validation is skipped for all inputs as if they
had been mentioned explicitly.

* List cached nodes in `execution_cached` message

This was previously just bugged in this PR.
2024-08-15 11:21:11 -04:00
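As referenced in the lazy-evaluation point above, here is an illustrative custom-node sketch of `check_lazy_status` and a `VALIDATE_INPUTS` that accepts `input_types`. The `{"lazy": True}` input option and the exact return conventions shown are assumptions for illustration, not code from this PR.

```
class MixImages:
    # Illustrative only: the exact input-option keys shown here are assumptions,
    # not the API introduced by this PR.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "image_a": ("IMAGE", {"lazy": True}),   # assumed "lazy" option
            "image_b": ("IMAGE", {"lazy": True}),
            "mix_factor": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0}),
        }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"

    def check_lazy_status(self, image_a, image_b, mix_factor):
        # Ask the executor only for the inputs this run actually needs:
        # at mix_factor 0.0 the second image is never evaluated, and vice-versa at 1.0.
        needed = []
        if mix_factor < 1.0 and image_a is None:
            needed.append("image_a")
        if mix_factor > 0.0 and image_b is None:
            needed.append("image_b")
        return needed

    @classmethod
    def VALIDATE_INPUTS(cls, input_types):
        # Accepting an `input_types` argument skips normal socket-type validation,
        # so variant ("*") connections can be checked (or allowed) here instead.
        return True

    def mix(self, image_a, image_b, mix_factor):
        if mix_factor == 0.0:
            return (image_a,)
        if mix_factor == 1.0:
            return (image_b,)
        return (image_a * (1.0 - mix_factor) + image_b * mix_factor,)
```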
comfyanonymous
0f9c2a7822 Try to fix SDXL OOM issue on some configurations. 2024-08-14 23:08:54 -04:00
comfyanonymous
153d0a8142 Add a update/update_comfyui_stable.bat to the standalones. 2024-08-14 22:29:23 -04:00
Chenlei Hu
ab4dd19b91 Remove legacy ui test files (#4316) 2024-08-14 21:01:06 -04:00
comfyanonymous
f1d6cef71c Revert "Disable cuda malloc by default."
This reverts commit 50bf66e5c4.
2024-08-14 08:38:07 -04:00
comfyanonymous
33fb282d5c Fix issue. 2024-08-14 02:51:47 -04:00
comfyanonymous
50bf66e5c4 Disable cuda malloc by default. 2024-08-14 02:49:25 -04:00
pythongosssss
e60e19b175 Add support for simple tooltips (#3842)
* Add support for simple tooltips

* Fix overflow

* Add tooltips for nodes in the default workflow

* new line

* Prevent potential crash

* PR feedback

* Hide tooltip when clicking (e.g. combo widget)

* Refactor tooltips, add node level support

* Fix

* move

* Fix test (and undo last change)

* Fixed indent

* Fix dom widgets, dont show tooltip if not over canvas
2024-08-14 01:22:10 -04:00
comfyanonymous
a5af64d3ce Revert "Not sure if this actually changes anything but it can't hurt."
This reverts commit 34608de2e9.
2024-08-14 01:05:17 -04:00
Robin Huang
3e52e0364c Add model downloading endpoint. (#4248)
* Add model downloading endpoint.

* Move client session init to async function.

* Break up large function.

* Send "download_progress" as websocket event.

* Fixed

* Fixed.

* Use async mock.

* Move server set up to right before run call.

* Validate that model subdirectory cannot contain relative paths.

* Add download_model test checking for invalid paths.

* Remove DS_Store.

* Consolidate DownloadStatus and DownloadModelResult

* Add progress_interval as an optional parameter.

* Use tuple type from annotations.

* Use pydantic.

* Update comment.

* Revert "Use pydantic."

This reverts commit 7461e8eb00.

* Add new line.

* Add newline EOF.

* Validate model filename as well.

* Add comment to not reply on internal.

* Restrict downloading to safetensor files only.
2024-08-13 15:48:52 -04:00
comfyanonymous
34608de2e9 Not sure if this actually changes anything but it can't hurt. 2024-08-13 13:29:16 -04:00
comfyanonymous
39fb74c5bd Fix bug when model cannot be partially unloaded. 2024-08-13 03:57:55 -04:00
comfyanonymous
74e124f4d7 Fix some issues with TE being in lowvram mode. 2024-08-12 23:42:21 -04:00
comfyanonymous
a562c17e8a load_unet -> load_diffusion_model with a model_options argument. 2024-08-12 23:20:57 -04:00
comfyanonymous
5942c17d55 Order of operations matters. 2024-08-12 21:56:18 -04:00
comfyanonymous
c032b11e07 xlabs Flux controlnet implementation. (#4260)
* xlabs Flux controlnet.

* Fix not working on old python.

* Remove comment.
2024-08-12 21:22:22 -04:00
comfyanonymous
b8ffb2937f Memory tweaks. 2024-08-12 15:07:11 -04:00
Vladimir Semyonov
ce37c11164 add DS_Store to gitignore (#4324) 2024-08-12 12:32:34 -04:00
Alex "mcmonkey" Goodwin
b5c3906b38 Automatically link the Comfy CI page on PRs (#4326)
also use_prior_commit so it doesn't get a janked merge commit instead of the real one
2024-08-12 12:32:16 -04:00
comfyanonymous
5d43e75e5b Fix some issues with the model sometimes not getting patched. 2024-08-12 12:27:54 -04:00
comfyanonymous
517f4a94e4 Fix some lora loading slowdowns. 2024-08-12 11:50:32 -04:00
comfyanonymous
52a471c5c7 Change name of log. 2024-08-12 10:35:06 -04:00
comfyanonymous
ad76574cb8 Fix some potential issues with the previous commits. 2024-08-12 00:23:29 -04:00
comfyanonymous
9acfe4df41 Support loading directly to vram with CLIPLoader node. 2024-08-12 00:06:01 -04:00
comfyanonymous
9829b013ea Fix mistake in last commit. 2024-08-12 00:00:17 -04:00
comfyanonymous
5c69cde037 Load TE model straight to vram if certain conditions are met. 2024-08-11 23:52:43 -04:00
comfyanonymous
e9589d6d92 Add a way to set model dtype and ops from load_checkpoint_guess_config. 2024-08-11 08:50:34 -04:00
comfyanonymous
0d82a798a5 Remove the ckpt_path from load_state_dict_guess_config. 2024-08-11 08:37:35 -04:00
ljleb
925fff26fd alternative to load_checkpoint_guess_config that accepts a loaded state dict (#4249)
* make alternative fn

* add back ckpt path as 2nd argument?
2024-08-11 08:36:52 -04:00
193 changed files with 117863 additions and 43503 deletions


@@ -75,6 +75,25 @@ else:
print("pulling latest changes")
pull(repo)
if "--stable" in sys.argv:
def latest_tag(repo):
versions = []
for k in repo.references:
try:
prefix = "refs/tags/v"
if k.startswith(prefix):
version = list(map(int, k[len(prefix):].split(".")))
versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
except:
pass
versions.sort()
if len(versions) > 0:
return versions[-1][1]
return None
latest_tag = latest_tag(repo)
if latest_tag is not None:
repo.checkout(latest_tag)
print("Done!")
self_update = True
@@ -115,3 +134,13 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
        shutil.copy(repo_req_path, req_path)
    except:
        pass

stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
try:
    if not file_size(stable_update_script_to) > 10:
        shutil.copy(stable_update_script, stable_update_script_to)
except:
    pass


@@ -0,0 +1,8 @@
@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
if exist update_new.py (
move /y update_new.py update.py
echo Running updater again since it got updated.
..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
)
if "%~1"=="" pause


@@ -14,7 +14,7 @@ run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
RECOMMENDED WAY TO UPDATE:


@@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause

.gitattributes

@@ -0,0 +1,2 @@
/web/assets/** linguist-generated
/web/** linguist-vendored


@@ -1,5 +1,8 @@
blank_issues_enabled: true
contact_links:
  - name: ComfyUI Frontend Issues
    url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
    about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
  - name: ComfyUI Matrix Space
    url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
    about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).


@@ -35,3 +35,19 @@ jobs:
      torch_version: ${{ matrix.torch_version }}
      google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
      comfyui_flags: ${{ matrix.flags }}
      use_prior_commit: 'true'
  comment:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    runs-on: ubuntu-latest
    permissions:
      pull-requests: write
    steps:
      - uses: actions/github-script@v6
        with:
          script: |
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
            })

.github/workflows/stale-issues.yml

@@ -0,0 +1,21 @@
name: 'Close stale issues'
on:
  schedule:
    # Run daily at 430 am PT
    - cron: '30 11 * * *'
permissions:
  issues: write

jobs:
  stale:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/stale@v9
        with:
          stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
          days-before-stale: 30
          days-before-close: 7
          stale-issue-label: 'Stale'
          only-labels: 'User Support'
          exempt-all-assignees: true
          exempt-all-milestones: true


@@ -1,10 +1,4 @@
# This is a temporary action during frontend TS migration.
# This file should be removed after TS migration is completed.
# The browser test is here to ensure TS repo is working the same way as the
# current JS code.
# If you are adding UI feature, please sync your changes to the TS repo:
# huchenlei/ComfyUI_frontend and update test expectation files accordingly.
name: Playwright Browser Tests CI
name: Test server launches without errors
on:
push:
@@ -21,15 +15,6 @@ jobs:
with:
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- name: Checkout ComfyUI_frontend
uses: actions/checkout@v4
with:
repository: "huchenlei/ComfyUI_frontend"
path: "ComfyUI_frontend"
ref: "fcc54d803e5b6a9b08a462a1d94899318c96dcbb"
- uses: actions/setup-node@v3
with:
node-version: lts/*
- uses: actions/setup-python@v4
with:
python-version: '3.8'
@@ -45,16 +30,6 @@ jobs:
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600
working-directory: ComfyUI
- name: Install ComfyUI_frontend dependencies
run: |
npm ci
working-directory: ComfyUI_frontend
- name: Install Playwright Browsers
run: npx playwright install --with-deps
working-directory: ComfyUI_frontend
- name: Run Playwright tests
run: npx playwright test
working-directory: ComfyUI_frontend
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
@@ -62,12 +37,6 @@ jobs:
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: ComfyUI_frontend/playwright-report/
retention-days: 30
- uses: actions/upload-artifact@v4
if: always()
with:


@@ -1,30 +0,0 @@
name: Tests CI
on: [push, pull_request]
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v3
        with:
          node-version: 18
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
      - name: Run Tests
        run: |
          npm ci
          npm run test:generate
          npm test -- --verbose
        working-directory: ./tests-ui
      - name: Run Unit Tests
        run: |
          pip install -r tests-unit/requirements.txt
          python -m pytest tests-unit


@@ -67,6 +67,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2

.gitignore

@@ -18,4 +18,5 @@ venv/
/tests-ui/data/object_info.json
/user/
*.log
web_custom_versions/
web_custom_versions/
.DS_Store


@@ -1,8 +1,35 @@
ComfyUI
=======
The most powerful and modular stable diffusion GUI and backend.
-----------
<div align="center">
# ComfyUI
**The most powerful and modular diffusion model GUI and backend.**
[![Website][website-shield]][website-url]
[![Dynamic JSON Badge][discord-shield]][discord-url]
[![Matrix][matrix-shield]][matrix-url]
<br>
[![][github-release-shield]][github-release-link]
[![][github-release-date-shield]][github-release-link]
[![][github-downloads-shield]][github-downloads-link]
[![][github-downloads-latest-shield]][github-downloads-link]
[matrix-shield]: https://img.shields.io/badge/Matrix-000000?style=flat&logo=matrix&logoColor=white
[matrix-url]: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
[website-shield]: https://img.shields.io/badge/ComfyOrg-4285F4?style=flat
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
[github-release-date-shield]: https://img.shields.io/github/release-date/comfyanonymous/ComfyUI?style=flat
[github-downloads-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/total?style=flat
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](comfyui_screenshot.png)
</div>
This UI will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and to see what ComfyUI can do, you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -48,6 +75,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
@@ -70,6 +98,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd for macOS users
@@ -105,17 +135,17 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have them already installed. This is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
This is the command to install pytorch nightly instead which might have performance improvements:
@@ -200,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
@@ -216,6 +246,47 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
### Using the Latest Frontend
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated weekly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
1. For the latest daily release, launch ComfyUI with this command line argument:
```
--front-end-version Comfy-Org/ComfyUI_frontend@latest
```
2. For a specific version, replace `latest` with the desired version number:
```
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
# QA
### Which GPU should I buy for this?

api_server/__init__.py

@@ -0,0 +1,3 @@
# ComfyUI Internal Routes
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.


@@ -0,0 +1,44 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory
from api_server.services.file_service import FileService
import app.logger

class InternalRoutes:
    '''
    The top level web router for internal routes: /internal/*
    The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
    Check README.md for more information.
    '''

    def __init__(self):
        self.routes: web.RouteTableDef = web.RouteTableDef()
        self._app: Optional[web.Application] = None
        self.file_service = FileService({
            "models": models_dir,
            "user": user_directory,
            "output": output_directory
        })

    def setup_routes(self):
        @self.routes.get('/files')
        async def list_files(request):
            directory_key = request.query.get('directory', '')
            try:
                file_list = self.file_service.list_files(directory_key)
                return web.json_response({"files": file_list})
            except ValueError as e:
                return web.json_response({"error": str(e)}, status=400)
            except Exception as e:
                return web.json_response({"error": str(e)}, status=500)

        @self.routes.get('/logs')
        async def get_logs(request):
            return web.json_response(app.logger.get_logs())

    def get_app(self):
        if self._app is None:
            self._app = web.Application()
            self.setup_routes()
            self._app.add_routes(self.routes)
        return self._app
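A short usage sketch of the two internal endpoints added above, assuming a local server on the default 127.0.0.1:8188 address (per the README, these routes are for frontend use only and may change):

```
import requests

BASE = "http://127.0.0.1:8188"  # default local ComfyUI address

# GET /internal/files lists files under one of the allowed directory keys.
files = requests.get(f"{BASE}/internal/files", params={"directory": "models"}).json()
print(files["files"][:5])

# GET /internal/logs returns the in-memory log buffer set up by app.logger.
print(requests.get(f"{BASE}/internal/logs").json())
```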


@@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
    def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
        self.allowed_directories: Dict[str, str] = allowed_directories
        self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

    def list_files(self, directory_key: str) -> List[FileSystemItem]:
        if directory_key not in self.allowed_directories:
            raise ValueError("Invalid directory key")
        directory_path: str = self.allowed_directories[directory_key]
        return self.file_system_ops.walk_directory(directory_path)


@@ -0,0 +1,42 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard

class FileInfo(TypedDict):
    name: str
    path: str
    type: Literal["file"]
    size: int

class DirectoryInfo(TypedDict):
    name: str
    path: str
    type: Literal["directory"]

FileSystemItem = Union[FileInfo, DirectoryInfo]

def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
    return item["type"] == "file"

class FileSystemOperations:
    @staticmethod
    def walk_directory(directory: str) -> List[FileSystemItem]:
        file_list: List[FileSystemItem] = []
        for root, dirs, files in os.walk(directory):
            for name in files:
                file_path = os.path.join(root, name)
                relative_path = os.path.relpath(file_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "file",
                    "size": os.path.getsize(file_path)
                })
            for name in dirs:
                dir_path = os.path.join(root, name)
                relative_path = os.path.relpath(dir_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "directory"
                })
        return file_list


@@ -8,7 +8,7 @@ import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict
from typing import TypedDict, Optional
import requests
from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
@@ -150,7 +151,7 @@ class FrontendManager:
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
@@ -158,15 +159,21 @@ class FrontendManager:
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
try:
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
os.rmdir(web_root)
return web_root
@classmethod

app/logger.py

@@ -0,0 +1,31 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque

logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def get_logs():
    return "\n".join([formatter.format(x) for x in logs])


def setup_logger(verbose: bool = False, capacity: int = 300):
    global logs
    if logs:
        return

    # Setup default global logger
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(stream_handler)

    # Create a memory handler with a deque as its buffer
    logs = deque(maxlen=capacity)
    memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
    memory_handler.buffer = logs
    memory_handler.setFormatter(formatter)
    logger.addHandler(memory_handler)
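A small usage sketch of this logger module, assuming a ComfyUI checkout on `sys.path`:

```
import logging
from app.logger import setup_logger, get_logs  # assumes a ComfyUI checkout on sys.path

setup_logger(verbose=False)
logging.info("model loaded")
print(get_logs())  # buffered records, formatted and joined with newlines
```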


@@ -92,6 +92,10 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
@@ -112,10 +116,14 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -171,10 +179,3 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)


@@ -88,10 +88,11 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@@ -123,7 +124,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):


@@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
@@ -146,7 +148,7 @@ class ControlBase:
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype:
if output_dtype is not None and x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
@@ -204,7 +206,6 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
@@ -234,7 +235,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -389,7 +390,8 @@ def controlnet_config(sd):
else:
operations = comfy.ops.disable_weight_init
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -403,12 +405,12 @@ def controlnet_load_state_dict(control_model, sd):
def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
@@ -416,10 +418,11 @@ def load_controlnet_mmdit(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL()
@@ -427,6 +430,33 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
for k in sd:
new_sd[k] = sd[k]
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
@@ -489,7 +519,12 @@ def load_controlnet(ckpt_path, model=None):
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
return load_controlnet_mmdit(controlnet_data)
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@@ -521,6 +556,7 @@ def load_controlnet(ckpt_path, model=None):
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)


@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
unet = comfy.sd.load_unet(unet_path)
unet = comfy.sd.load_diffusion_model(unet_path)
clip = None
if output_clip:

comfy/float.py

@@ -0,0 +1,62 @@
import torch
import math

def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    mantissa_scaled = torch.where(
        normal_mask,
        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
    )

    mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
    return mantissa_scaled.floor() / (2**MANTISSA_BITS)

#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
    if dtype == torch.float8_e4m3fn:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
    elif dtype == torch.float8_e5m2:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
    else:
        raise ValueError("Unsupported dtype")

    x = x.half()
    sign = torch.sign(x)
    abs_x = x.abs()
    sign = torch.where(abs_x == 0, 0, sign)

    # Combine exponent calculation and clamping
    exponent = torch.clamp(
        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
        0, 2**EXPONENT_BITS - 1
    )

    # Combine mantissa calculation and rounding
    normal_mask = ~(exponent == 0)

    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)

    sign *= torch.where(
        normal_mask,
        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
        (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
    )

    del abs_x

    return sign.to(dtype=dtype)


def stochastic_rounding(value, dtype, seed=0):
    if dtype == torch.float32:
        return value.to(dtype=torch.float32)
    if dtype == torch.float16:
        return value.to(dtype=torch.float16)
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        generator = torch.Generator(device=value.device)
        generator.manual_seed(seed)
        return manual_stochastic_round_to_float8(value, dtype, generator=generator)

    return value.to(dtype=dtype)


@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
import comfy.model_patcher
import comfy.model_sampling
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@@ -509,6 +510,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -541,6 +545,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
if sigmas[i] == 1.0:
sigma_s = 0.9999
else:
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
r = 1 / 2
h = t_down - t_i
s = t_i + r * h
sigma_s = sigma_fn(s)
# sigma_s = sigmas[i+1]
sigma_s_i_ratio = sigma_s / sigmas[i]
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
D_i = model(u, sigma_s * s_in, **extra_args)
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
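A quick numeric sketch (values chosen arbitrarily) of the quantities the rectified-flow ancestral sampler above derives from a single sigma step:

sigma_i, sigma_ip1, eta = 0.8, 0.6, 1.0
downstep_ratio = 1 + (sigma_ip1 / sigma_i - 1) * eta        # 0.75
sigma_down = sigma_ip1 * downstep_ratio                     # 0.45
alpha_ip1, alpha_down = 1 - sigma_ip1, 1 - sigma_down       # 0.4, 0.55
renoise_coeff = (sigma_ip1**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2) ** 0.5
print(sigma_down, renoise_coeff)                            # ~0.45, ~0.50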
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""

View File

@@ -141,6 +141,7 @@ class StableAudio1(LatentFormat):
latent_channels = 64
class Flux(SD3):
latent_channels = 16
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
@@ -162,6 +163,7 @@ class Flux(SD3):
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
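A tiny worked example (constants copied from the hunk above; process_out is assumed to be the usual inverse defined on the base class) of the Flux latent scaling:

scale_factor, shift_factor = 0.3611, 0.1159
raw = 1.0
working = (raw - shift_factor) * scale_factor      # process_in  -> ~0.3193
restored = working / scale_factor + shift_factor   # process_out -> 1.0
print(working, restored)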

View File

@@ -1,4 +1,5 @@
import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight, eps=1e-6):
if rms_norm_torch is not None:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
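A standalone check (torch only, no comfy imports) that the fallback branch above matches torch.nn.functional.rms_norm where that function exists:

import torch
x, w, eps = torch.randn(2, 8, 64), torch.randn(64), 1e-6
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
if hasattr(torch.nn.functional, "rms_norm"):
    builtin = torch.nn.functional.rms_norm(x, w.shape, weight=w, eps=eps)
    print(torch.allclose(manual, builtin, atol=1e-5))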

View File

@@ -0,0 +1,151 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import torch
import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
if not self.latent_input:
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
controlnet_double = ()
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
img = torch.cat((txt, img), 1)
controlnet_single = ()
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_double:
out_input += (x,) * repeat
else:
out_input = (controlnet_double * repeat)
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
else:
hint = hint * 2.0 - 1.0
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
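A small sketch (plain Python, block outputs replaced by strings) of how the ceil-repeat logic above stretches a few ControlNet block outputs across the 19 double blocks of the main Flux model:

import math
main_model_double = 19
controlnet_double = ("d0", "d1", "d2", "d3", "d4")              # e.g. a 5-block ControlNet
repeat = math.ceil(main_model_double / len(controlnet_double))  # 4
latent_input = True
if latent_input:
    out_input = ()
    for x in controlnet_double:
        out_input += (x,) * repeat          # d0 d0 d0 d0 d1 d1 ... (blockwise)
else:
    out_input = controlnet_double * repeat  # d0 d1 d2 d3 d4 d0 ... (cyclic)
print(out_input[:main_model_double])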

View File

@@ -6,6 +6,7 @@ from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module):
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
@@ -170,15 +168,15 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img += img_mod1.gate * self.img_attn.proj(img_attn)
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt blocks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504)
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
@@ -233,7 +231,7 @@ class SingleStreamBlock(nn.Module):
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
if x.dtype == torch.float16:
x = x.clip(-65504, 65504)
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
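A minimal illustration (standalone, torch only) of why the hunk above swaps clip() for nan_to_num() on float16 activations: clip clamps infinities but leaves NaNs in place.

import torch
t = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0])
print(t.clip(-65504, 65504))                                      # NaN survives the clamp
print(torch.nan_to_num(t, nan=0.0, posinf=65504, neginf=-65504))  # NaN -> 0, +/-inf -> +/-65504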

View File

@@ -38,7 +38,7 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -112,18 +114,34 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
for i, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -138,5 +156,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
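A rough sketch (shapes picked arbitrarily; einops required) of the 2x2 patchification that Flux.forward performs before calling forward_orig: pad to a multiple of the patch size, then fold each patch into the channel dimension so the transformer sees a token sequence.

import torch
from einops import rearrange

bs, c, h, w = 1, 16, 107, 75
x = torch.randn(bs, c, h, w)
patch_size = 2
pad_h = (patch_size - h % patch_size) % patch_size
pad_w = (patch_size - w % patch_size) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size
print(img.shape, h_len * w_len)  # torch.Size([1, 2052, 64]) 2052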

View File

@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)

View File

@@ -358,7 +358,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(

View File

@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
class SwiGLUFeedForward(nn.Module):

View File

@@ -16,8 +16,12 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import logging
import torch
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@@ -318,7 +322,235 @@ def model_lora_keys_unet(model, key_map={}):
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
a ValueError is raised; the tensor is never silently truncated.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
diff: torch.Tensor = v[0]
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
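A hedged usage sketch (assumes it is run inside the ComfyUI tree from this diff; the key name is hypothetical) of the now-standalone comfy.lora helpers: a plain "diff" patch applied through calculate_weight, plus the new zero-padding helper the pad_weight flag relies on.

import torch
import comfy.lora

weight = torch.zeros(4, 4)
diff = torch.full((4, 4), 0.5)
# (strength, value, strength_model, offset, function), as unpacked at the top of calculate_weight
patches = [(1.0, (diff,), 1.0, None, None)]
print(comfy.lora.calculate_weight(patches, weight.clone(), "hypothetical.key"))

# pad_tensor_to_shape zero-pads a tensor out to a larger shape
print(comfy.lora.pad_tensor_to_shape(torch.ones(2, 2), [3, 4]))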

View File

@@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -77,10 +95,10 @@ class BaseModel(torch.nn.Module):
self.device = device
if not unet_config.get("disable_unet_model_creation", False):
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
if model_config.custom_operations is None:
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
else:
operations = comfy.ops.disable_weight_init
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)

View File

@@ -472,9 +472,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True

View File

@@ -44,9 +44,15 @@ cpu_state = CPUState.GPU
total_vram = 0
lowvram_available = True
xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
except:
pass
lowvram_available = True
if args.deterministic:
logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
@@ -66,10 +72,10 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
_ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available()
except:
pass
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
try:
if torch.backends.mps.is_available():
@@ -189,7 +195,6 @@ VAE_DTYPES = [torch.float32]
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@@ -315,17 +320,15 @@ class LoadedModel:
self.model_use_more_vram(use_more_vram)
else:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.weights_loaded = True
return self.real_model
@@ -338,8 +341,9 @@ class LoadedModel:
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
self.model.partially_unload(self.model.offload_device, memory_to_free)
return False
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
@@ -366,8 +370,21 @@ def offloaded_memory(loaded_models, device):
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
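Back-of-the-envelope arithmetic (constants copied from the hunk above; the --reserve-vram value is a made-up example) for how the reservation now feeds minimum_inference_memory():

GB = 1024 * 1024 * 1024
extra_reserved = 600 * 1024 * 1024           # Windows default from the diff
reserve_vram_arg = 1.5                       # e.g. --reserve-vram 1.5, hypothetical value in GB
extra_reserved = int(reserve_vram_arg * GB)  # an explicit flag overrides the platform default
minimum_inference = 0.8 * GB + extra_reserved
print(minimum_inference / GB)                # ~2.3 GB kept free during inference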
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
@@ -391,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
@@ -434,15 +453,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models)
@@ -513,7 +532,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
@@ -552,7 +571,9 @@ def loaded_models(only_currently_used=False):
def cleanup_models(keep_clone_weights_loaded=False):
to_delete = []
for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
#TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
@@ -659,6 +680,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@@ -684,6 +706,20 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
if is_device_mps(load_device):
return offload_device
mem_l = get_free_memory(load_device)
mem_o = get_free_memory(offload_device)
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
return load_device
else:
return offload_device
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
@@ -860,7 +896,8 @@ def pytorch_attention_flash_attention():
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
upcast = True
except:
pass
@@ -956,23 +993,23 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip:
return True
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
if props.major < 6:
return False
fp16_works = False
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
#when the model doesn't actually fit on the card
#TODO: actually test if GP106 and others have the same type of behavior
#FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
fp16_works = True
if WINDOWS or manual_cast:
return True
else:
return False #weird linux behavior where fp32 is faster
if fp16_works or manual_cast:
if manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -1012,7 +1049,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@@ -1025,6 +1062,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -22,32 +22,26 @@ import inspect
import logging
import uuid
import collections
import math
import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
from comfy.types import UnetWrapperFunction
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -90,12 +84,11 @@ def wipe_lowvram_weight(m):
m.bias_function = None
class LowVramPatch:
def __init__(self, key, model_patcher):
def __init__(self, key, patches):
self.key = key
self.model_patcher = model_patcher
self.patches = patches
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
@@ -327,46 +320,36 @@ class ModelPatcher:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
self.patch_weight_to_device(key, device_to)
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = self.model_size()
return self.model
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
loading = []
for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
module_mem = x[0]
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
if m.comfy_cast_weights:
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
weight_key = "{}.weight".format(n)
@@ -377,13 +360,13 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self)
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self)
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -394,202 +377,56 @@ class ModelPatcher:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
mem_counter += comfy.model_management.module_size(m)
param = list(m.parameters())
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
continue
mem_counter += module_mem
load_completely.append((module_mem, n, m))
self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
self.patch_weight_to_device(bias_key)
m.to(device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
if lowvram_model_memory == 0:
full_load = True
else:
full_load = False
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model.model_lowvram:
@@ -615,6 +452,10 @@ class ModelPatcher:
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
del m.comfy_patched_weights
keys = list(self.object_patches_backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@@ -624,40 +465,47 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
unload_list = []
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
for n, m in self.model.named_modules():
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
unload_list.append((module_mem, n, m))
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
@@ -665,13 +513,20 @@ class ModelPatcher:
return memory_freed
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
self.patch_model(load_weights=False)
full_load = False
if self.model.model_lowvram == False:
return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
pass #TODO: Full load
full_load = True
current_used = self.model.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self):
return self.model.device
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
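A standalone check (pure Python plus the standard library; no repo imports) that string_to_seed above is an ordinary CRC-32, so the stochastic fp8 rounding seed stays stable across runs for a given parameter name:

import binascii

def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0xEDB88320 if crc & 1 else crc >> 1
    return crc ^ 0xFFFFFFFF

key = b"diffusion_model.double_blocks.0.img_attn.qkv.weight"
print(string_to_seed(key) == binascii.crc32(key))  # expected: True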

View File

@@ -18,29 +18,42 @@
import torch
import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
def cast_to(weight, dtype=None, device=None, non_blocking=False):
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_input(weight, input, non_blocking=False):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
if s.bias_function is not None:
has_function = s.bias_function is not None
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
if s.weight_function is not None:
has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
@@ -238,3 +251,59 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
self.scale_weight = None
self.scale_input = None
return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast
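A small standalone restatement (copied from the hunk above so it can run without the repo) showing why the new copy flag matters: without it, a weight_function such as LowVramPatch could end up mutating the stored parameter in place instead of a temporary.

import torch

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
                return weight
        return weight.to(dtype=dtype, copy=copy)
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight, non_blocking=non_blocking)
    return r

w = torch.randn(3, 3)
print(cast_to(w) is w)                            # True: the stored weight itself
print(cast_to(w, dtype=w.dtype, copy=True) is w)  # False: a private copy, safe to patch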

View File

@@ -24,6 +24,7 @@ import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.model_patcher
import comfy.lora
@@ -62,7 +63,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
if no_init:
return
params = target.params.copy()
@@ -71,20 +72,29 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
dtype = model_management.text_encoder_dtype(load_device)
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
params['model_options'] = model_options
self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
def clone(self):
n = CLIP(no_init=True)
@@ -390,11 +400,14 @@ class CLIPType(Enum):
HUNYUAN_DIT = 5
FLUX = 6
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
class EmptyClass:
pass
@@ -431,8 +444,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
@@ -456,7 +474,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
parameters = 0
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@@ -498,15 +520,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path)
sd_keys = sd.keys()
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
clip = None
clipvision = None
vae = None
model = None
model_patcher = None
clip_target = None
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
@@ -515,13 +541,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("weight_dtype", None)
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -545,7 +576,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@@ -567,12 +599,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -614,6 +647,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
@@ -622,24 +656,36 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path, dtype=None):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd, dtype=dtype)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def load_unet(unet_path, dtype=None):
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
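# --- Illustrative sketch (editor's addition, not part of the diff above) ---
# The new model_options / te_model_options dicts carry loader overrides. The keys
# read in this file are "dtype", "weight_dtype" and "custom_operations"; the file
# paths below are example values, not files shipped with ComfyUI.
import torch
import comfy.sd
import comfy.ops

model_patcher = comfy.sd.load_diffusion_model(
    "models/unet/flux1-dev.safetensors",              # example path
    model_options={"dtype": torch.float8_e4m3fn})     # request a specific weight dtype

clip = comfy.sd.load_clip(
    ["models/clip/clip_l.safetensors", "models/clip/t5xxl_fp16.safetensors"],  # example paths
    clip_type=comfy.sd.CLIPType.FLUX,
    model_options={"custom_operations": comfy.ops.manual_cast})  # override the op set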


@@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
@@ -84,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
@@ -94,7 +93,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
self.operations = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers
@@ -552,8 +555,12 @@ class SD1Tokenizer:
def state_dict(self):
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
super().__init__()
if name is not None:
@@ -563,7 +570,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()
if dtype is not None:


@@ -3,14 +3,14 @@ import torch
import os
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
@@ -38,10 +38,10 @@ class SDXLTokenizer:
return {}
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
def set_clip_options(self, options):
@@ -66,8 +66,8 @@ class SDXLClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
@@ -79,14 +79,14 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, model_options=model_options)


@@ -181,7 +181,7 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
memory_usage_factor = 0.7
memory_usage_factor = 0.8
def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
@@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
dtype_t5 = None
if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))


@@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from . import model_base
from . import utils
@@ -30,6 +48,7 @@ class BASE:
memory_usage_factor = 2.0
manual_cast_dtype = None
custom_operations = None
@classmethod
def matches(s, unet_config, state_dict=None):


@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os
class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,5 +18,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)


@@ -6,9 +6,9 @@ import torch
import os
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -35,11 +35,11 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
@@ -66,6 +66,6 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_


@@ -7,9 +7,9 @@ import os
import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,9 +18,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -50,10 +50,10 @@ class HyditTokenizer:
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:


@@ -0,0 +1,25 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}


@@ -0,0 +1,19 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)


@@ -4,9 +4,9 @@ import comfy.text_encoders.t5
import os
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -18,5 +18,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)


@@ -2,13 +2,13 @@ from comfy import sd1_clip
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
@@ -19,5 +19,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)


@@ -8,14 +8,14 @@ import comfy.model_management
import logging
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SD3Tokenizer:
@@ -38,24 +38,24 @@ class SD3Tokenizer:
return {}
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
@@ -132,6 +132,6 @@ class SD3ClipModel(torch.nn.Module):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_


@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
}
for k in MAP_BASIC:

comfy_execution/caching.py (new file, 308 lines)

@@ -0,0 +1,308 @@
import itertools
from typing import Sequence, Mapping
from comfy_execution.graph import DynamicPrompt
import nodes
from comfy_execution.graph_utils import is_link
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
self.subcache_keys = {}
def add_keys(self, node_ids):
raise NotImplementedError()
def all_node_ids(self):
return set(self.keys.keys())
def get_used_keys(self):
return self.keys.values()
def get_used_subcache_keys(self):
return self.subcache_keys.values()
def get_data_key(self, node_id):
return self.keys.get(node_id, None)
def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None)
class Unhashable:
def __init__(self):
self.value = float("NaN")
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
else:
# TODO - Support other objects like tensors?
return Unhashable()
class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool:
return False
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id):
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
return signature
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
ancestors = []
order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:
if is_link(inputs[key]):
ancestor_id = inputs[key][0]
if ancestor_id not in order_mapping:
ancestors.append(ancestor_id)
order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
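# --- Illustrative sketch (editor's addition, not part of the file above) ---
# A simplified, self-contained version of the ancestor-ordering idea: ancestors
# are discovered in sorted-input-key order, so the ordering (and any cache
# signature built from it) is deterministic for a given graph.
def ordered_ancestors(prompt, node_id, ancestors=None, order=None):
    # prompt: {node_id: {"inputs": {name: constant or [node_id, socket] link}}}
    ancestors = [] if ancestors is None else ancestors
    order = {} if order is None else order
    inputs = prompt.get(node_id, {}).get("inputs", {})
    for key in sorted(inputs):
        value = inputs[key]
        if isinstance(value, list) and len(value) == 2:  # looks like a link
            parent = value[0]
            if parent not in order:
                ancestors.append(parent)
                order[parent] = len(ancestors) - 1
                ordered_ancestors(prompt, parent, ancestors, order)
    return ancestors, order

# ordered_ancestors({"3": {"inputs": {"a": ["1", 0], "b": ["2", 0]}},
#                    "1": {"inputs": {}}, "2": {"inputs": {}}}, "3")
# -> (["1", "2"], {"1": 0, "2": 1})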
class BasicCache:
def __init__(self, key_class):
self.key_class = key_class
self.initialized = False
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
self.is_changed_cache = is_changed_cache
self.initialized = True
def all_node_ids(self):
assert self.initialized
node_ids = self.cache_key_set.all_node_ids()
for subcache in self.subcaches.values():
node_ids = node_ids.union(subcache.all_node_ids())
return node_ids
def _clean_cache(self):
preserve_keys = set(self.cache_key_set.get_used_keys())
to_remove = []
for key in self.cache:
if key not in preserve_keys:
to_remove.append(key)
for key in to_remove:
del self.cache[key]
def _clean_subcaches(self):
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
to_remove = []
for key in self.subcaches:
if key not in preserve_subcaches:
to_remove.append(key)
for key in to_remove:
del self.subcaches[key]
def clean_unused(self):
assert self.initialized
self._clean_cache()
self._clean_subcaches()
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
else:
return None
def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache
def _get_subcache(self, node_id):
assert self.initialized
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
return self.subcaches[subcache_key]
else:
return None
def recursive_debug_dump(self):
result = []
for key in self.cache:
result.append({"key": key, "value": self.cache[key]})
for key in self.subcaches:
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class):
super().__init__(key_class)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
parent_id = self.dynprompt.get_parent_node_id(node_id)
if parent_id is None:
return self
hierarchy = []
while parent_id is not None:
hierarchy.append(parent_id)
parent_id = self.dynprompt.get_parent_node_id(parent_id)
cache = self
for parent_id in reversed(hierarchy):
cache = cache._get_subcache(parent_id)
if cache is None:
return None
return cache
def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return cache._get_immediate(node_id)
def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
self.used_generation = {}
self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1
for node_id in node_ids:
self._mark_used(node_id)
def clean_unused(self):
while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
for key in to_remove:
del self.cache[key]
del self.used_generation[key]
if key in self.children:
del self.children[key]
self._clean_subcaches()
def get(self, node_id):
self._mark_used(node_id)
return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
def set(self, node_id, value):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids)
self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = []
for child_id in children_ids:
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self

comfy_execution/graph.py (new file, 259 lines)

@@ -0,0 +1,259 @@
import nodes
from comfy_execution.graph_utils import is_link
class DependencyCycleError(Exception):
pass
class NodeInputError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt:
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")
def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
class TopologicalSort:
def __init__(self, dynprompt):
self.dynprompt = dynprompt
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
if to_input not in inputs:
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
value = inputs[to_input]
if not is_link(value):
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
from_node_id, from_socket = value
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
if unique_id in self.pendingNodes:
return
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
for blocked_node_id in self.blocking[unique_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[unique_id]
def is_empty(self):
return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort):
"""
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
it can still be returned to the graph after having further dependencies added.
"""
def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if self.output_cache.get(from_node_id) is not None:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None, None, None
available = self.get_ready_nodes()
if len(available) == 0:
cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation,
# we will 'blame' the first node in the cycle that is not a static node.
blamed_node = cycled_nodes[0]
for node_id in cycled_nodes:
display_node_id = self.dynprompt.get_display_node_id(node_id)
if display_node_id != node_id:
blamed_node = display_node_id
break
ex = DependencyCycleError("Dependency cycle detected")
error_details = {
"node_id": blamed_node,
"exception_message": str(ex),
"exception_type": "graph.DependencyCycleError",
"traceback": [],
"current_inputs": []
}
return None, error_details, ex
self.staged_node_id = self.ux_friendly_pick_node(available)
return self.staged_node_id, None, None
def ux_friendly_pick_node(self, node_list):
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False
for node_id in node_list:
if is_output(node_id):
return node_id
#This should handle the VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
if is_output(blocked_node_id):
return node_id
#This should handle the VAELoader -> VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
for blocked_node_id1 in self.blocking[blocked_node_id]:
if is_output(blocked_node_id1):
return node_id
#TODO: this function should be improved
return node_list[0]
def unstage_node_execution(self):
assert self.staged_node_id is not None
self.staged_node_id = None
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.staged_node_id = None
def get_nodes_in_cycle(self):
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
# the code simple (and because having a cycle in the first place is a catastrophic error)
blocked_by = { node_id: {} for node_id in self.pendingNodes }
for from_node_id in self.blocking:
for to_node_id in self.blocking[from_node_id]:
if True in self.blocking[from_node_id][to_node_id].values():
blocked_by[to_node_id][from_node_id] = True
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
while len(to_remove) > 0:
for node_id in to_remove:
for to_node_id in blocked_by:
if node_id in blocked_by[to_node_id]:
del blocked_by[to_node_id][node_id]
del blocked_by[node_id]
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys())
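# --- Illustrative sketch (editor's addition, not part of the file above) ---
# The "topological dissolve" described in the ExecutionList docstring, in
# miniature: a node becomes ready when its block count reaches zero, and
# popping it unblocks its dependents. Anything left over is part of a cycle.
def dissolve(depends_on):
    # depends_on: {node: set of nodes it depends on}
    block_count = {n: len(deps) for n, deps in depends_on.items()}
    blocking = {n: [] for n in depends_on}
    for n, deps in depends_on.items():
        for d in deps:
            blocking[d].append(n)
    order = []
    ready = [n for n, c in block_count.items() if c == 0]
    while ready:
        n = ready.pop()
        order.append(n)
        for m in blocking[n]:
            block_count[m] -= 1
            if block_count[m] == 0:
                ready.append(m)
    return order

# dissolve({"load": set(), "sample": {"load"}, "save": {"sample"}})
# -> ["load", "sample", "save"]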
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
If the message is None, execution will be blocked silently instead.
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
possible, a lazy input will be more efficient and have a better user experience.
This functionality is useful in two cases:
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
lazy evaluation to let it conditionally disable itself.)
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message
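# --- Illustrative sketch (editor's addition, not part of the file above) ---
# A hypothetical custom node using ExecutionBlocker as described in the
# docstring; the class name and inputs are invented for illustration.
from comfy_execution.graph import ExecutionBlocker

class GateImage:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",),
                             "enabled": ("BOOLEAN", {"default": True})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "gate"
    CATEGORY = "example"

    def gate(self, image, enabled):
        if not enabled:
            # Consumers of this output are blocked; pass a string instead of
            # None to surface an error message to the user.
            return (ExecutionBlocker(None),)
        return (image,)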


@@ -0,0 +1,139 @@
def is_link(obj):
if not isinstance(obj, list):
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
return False
return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
_default_prefix_root = ""
_default_prefix_call_index = 0
_default_prefix_graph_index = 0
def __init__(self, prefix = None):
if prefix is None:
self.prefix = GraphBuilder.alloc_prefix()
else:
self.prefix = prefix
self.nodes = {}
self.id_gen = 1
@classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index
cls._default_prefix_graph_index = graph_index
@classmethod
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
if root is None:
root = GraphBuilder._default_prefix_root
if call_index is None:
call_index = GraphBuilder._default_prefix_call_index
if graph_index is None:
graph_index = GraphBuilder._default_prefix_graph_index
result = f"{root}.{call_index}.{graph_index}."
GraphBuilder._default_prefix_graph_index += 1
return result
def node(self, class_type, id=None, **kwargs):
if id is None:
id = str(self.id_gen)
self.id_gen += 1
id = self.prefix + id
if id in self.nodes:
return self.nodes[id]
node = Node(id, class_type, kwargs)
self.nodes[id] = node
return node
def lookup_node(self, id):
id = self.prefix + id
return self.nodes.get(id)
def finalize(self):
output = {}
for node_id, node in self.nodes.items():
output[node_id] = node.serialize()
return output
def replace_node_output(self, node_id, index, new_value):
node_id = self.prefix + node_id
to_remove = []
for node in self.nodes.values():
for key, value in node.inputs.items():
if is_link(value) and value[0] == node_id and value[1] == index:
if new_value is None:
to_remove.append((node, key))
else:
node.inputs[key] = new_value
for node, key in to_remove:
del node.inputs[key]
def remove_node(self, id):
id = self.prefix + id
del self.nodes[id]
class Node:
def __init__(self, id, class_type, inputs):
self.id = id
self.class_type = class_type
self.inputs = inputs
self.override_display_id = None
def out(self, index):
return [self.id, index]
def set_input(self, key, value):
if value is None:
if key in self.inputs:
del self.inputs[key]
else:
self.inputs[key] = value
def get_input(self, key):
return self.inputs.get(key)
def set_override_display_id(self, override_display_id):
self.override_display_id = override_display_id
def serialize(self):
serialized = {
"class_type": self.class_type,
"inputs": self.inputs
}
if self.override_display_id is not None:
serialized["override_display_id"] = self.override_display_id
return serialized
def add_graph_prefix(graph, outputs, prefix):
# Change the node IDs and any internal links
new_graph = {}
for node_id, node_info in graph.items():
# Make sure the added nodes have unique IDs
new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} }
for input_name, input_value in node_info.get("inputs", {}).items():
if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
else:
new_node["inputs"][input_name] = input_value
new_graph[new_node_id] = new_node
# Change the node IDs in the outputs
new_outputs = []
for n in range(len(outputs)):
output = outputs[n]
if is_link(output):
new_outputs.append([prefix + output[0], output[1]])
else:
new_outputs.append(output)
return new_graph, tuple(new_outputs)
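# --- Illustrative sketch (editor's addition, not part of the file above) ---
# Building a small graph fragment with GraphBuilder. The node class names are
# ordinary ComfyUI nodes; the checkpoint filename is an example value.
from comfy_execution.graph_utils import GraphBuilder

graph = GraphBuilder()
ckpt = graph.node("CheckpointLoaderSimple", ckpt_name="model.safetensors")
encode = graph.node("CLIPTextEncode", clip=ckpt.out(1), text="a photo of a cat")
prompt_fragment = graph.finalize()
# finalize() returns {prefixed_node_id: {"class_type": ..., "inputs": ...}} with
# links in the [node_id, output_index] form that is_link() recognizes.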


@@ -333,6 +333,25 @@ class VAESave:
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
class ModelSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
@@ -344,4 +363,9 @@ NODE_CLASS_MAPPINGS = {
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}


@@ -4,14 +4,14 @@ class Example:
Class methods
-------------
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re-executed.
Attributes
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
@@ -23,13 +23,19 @@ class Example:
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""
@@ -54,7 +60,8 @@ class Example:
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64, #Slider's step
"display": "number" # Cosmetic only: display as "number" or "slider"
"display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}),
"float_field": ("FLOAT", {
"default": 1.0,
@@ -62,11 +69,14 @@ class Example:
"max": 10.0,
"step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number"}),
"display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!"
"default": "Hello World!",
"lazy": True
}),
},
}
@@ -80,6 +90,23 @@ class Example:
CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
if print_to_screen == "enable":
return ["int_field", "float_field", "string_field"]
else:
return []
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:


@@ -5,6 +5,7 @@ import threading
import heapq
import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
@@ -12,102 +13,219 @@ import torch
import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1
PENDING = 2
class DuplicateNodeError(Exception):
pass
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
self.dynprompt = dynprompt
self.outputs_cache = outputs_cache
self.is_changed = {}
def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"):
self.is_changed[node_id] = False
return self.is_changed[node_id]
if "is_changed" in node:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN")
finally:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
missing_keys = {}
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_type, input_category, input_info = get_input_info(class_def, x)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
input_data_all[x] = (None,)
if outputs is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
mark_missing()
continue
obj = outputs[input_unique_id][output_index]
if output_index >= len(cached_output):
mark_missing()
continue
obj = cached_output[output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
elif input_category is not None:
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [prompt]
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
return input_data_all
return input_data_all, missing_keys
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
input_is_list = obj.INPUT_IS_LIST
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
if len(input_data_all) == 0:
max_len_input = 0
else:
max_len_input = max([len(x) for x in input_data_all.values()])
max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
for k, v in inputs.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb else v
break
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs))
else:
results.append(execution_block)
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
process_inputs(input_data_all, 0)
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)())
else:
process_inputs({})
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i)
return results
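
A hedged standalone illustration (not part of the diff) of the broadcasting rule used above: inputs are sliced per index, and shorter lists simply repeat their last element, so the node function runs once per index of the longest list. slice_dict is re-implemented here purely for demonstration.

def slice_dict(d, i):
    return {k: v[i if len(v) > i else -1] for k, v in d.items()}

input_data_all = {
    "image": ["img_a", "img_b", "img_c"],  # three queued values
    "strength": [0.5],                     # single value, reused for every call
}
max_len_input = max(len(v) for v in input_data_all.values())
for i in range(max_len_input):
    print(slice_dict(input_data_all, i))
# {'image': 'img_a', 'strength': 0.5}
# {'image': 'img_b', 'strength': 0.5}
# {'image': 'img_c', 'strength': 0.5}
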
def get_output_data(obj, input_data_all):
def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
if 'expand' in r:
# Perform an expansion, but do not append results
has_subgraph = True
new_graph = r['expand']
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif 'result' in r:
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
subgraph_results.append((None, r))
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
else:
output = []
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
return output, ui, has_subgraph
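
For context, a hedged sketch (not part of the diff) of a node that triggers the subgraph path above by returning an "expand" dictionary. The wrapper class name is made up for illustration; the GraphBuilder calls follow the comfy_execution.graph_utils API imported at the top of this file, and ImageInvert is an ordinary ComfyUI node.

from comfy_execution.graph_utils import GraphBuilder

class InvertTwiceExample:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "expand"
    CATEGORY = "Example"

    def expand(self, image):
        graph = GraphBuilder()
        first = graph.node("ImageInvert", image=image)
        second = graph.node("ImageInvert", image=first.out(0))
        # "result" may reference outputs of the newly added nodes; the
        # executor resolves those links once the expanded nodes have run.
        return {"result": (second.out(0),), "expand": graph.finalize()}
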
def format_value(x):
if x is None:
@@ -117,53 +235,145 @@ def format_value(x):
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return (True, None, None)
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
if unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id]
resolved_outputs = []
for is_subgraph, result in cached_results:
if not is_subgraph:
resolved_outputs.append(result)
else:
resolved_output = []
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
resolved_output.append(o)
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
obj = class_def()
object_storage[(unique_id, class_type)] = obj
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
else:
resolved_output.append(r)
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
has_subgraph = False
else:
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
)]
if len(required_inputs) > 0:
for i in required_inputs:
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
def execution_block_cb(block):
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": f"Execution Blocked: {block.message}",
"exception_type": "ExecutionBlocked",
"traceback": [],
"current_inputs": [],
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
return ExecutionBlocker(None)
else:
return block
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
new_output_ids = []
new_output_links = []
for i in range(len(output_data)):
new_graph, node_outputs = output_data[i]
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
new_output_ids.append(node_id)
for i in range(len(node_outputs)):
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
new_output_links.append((from_node_id, from_socket))
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
"node_id": unique_id,
"node_id": real_node_id,
}
return (False, error_details, iex)
return (ExecutionResult.FAILURE, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
@@ -173,121 +383,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
logging.error(f"!!! Exception during processing!!! {ex}")
logging.error(f"!!! Exception during processing !!! {ex}")
logging.error(traceback.format_exc())
error_details = {
"node_id": unique_id,
"node_id": real_node_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
"current_inputs": input_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (False, error_details, ex)
return (ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id)
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item
if unique_id in memo:
return memo[unique_id]
inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
return []
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_changed_old = ''
is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs:
return True
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif class_type != old_prompt[unique_id]['class_type']:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if to_delete:
d = outputs.pop(unique_id)
del d
return to_delete
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server):
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
self.server = server
self.reset()
def reset(self):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.caches = CacheSet(self.lru_size)
self.status_messages = []
self.success = True
self.old_prompt = {}
def add_message(self, event, data: dict, broadcast: bool):
data = {
@@ -318,26 +443,13 @@ class PromptExecutor:
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
"current_outputs": list(current_outputs),
}
self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
@@ -351,65 +463,59 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached",
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
executed = set()
output_node_id = None
to_execute = []
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
to_execute += [(0, node_id)]
execution_list.add_node(node_id)
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
memo = {}
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
while not execution_list.is_empty():
node_id, error, ex = execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@@ -426,31 +532,37 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = []
valid = True
validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None
received_types = {}
for x in required_inputs:
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
assert extra_info is not None
if x not in inputs:
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
if input_category == "required":
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
}
errors.append(error)
continue
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
info = (type_input, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@@ -469,8 +581,9 @@ def validate_inputs(prompt, item, validated):
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
@@ -521,6 +634,9 @@ def validate_inputs(prompt, item, validated):
if type_input == "STRING":
val = str(val)
inputs[x] = val
if type_input == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
@@ -536,11 +652,11 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
if x not in validate_function_inputs and not validate_has_kwargs:
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@@ -550,10 +666,10 @@ def validate_inputs(prompt, item, validated):
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@@ -564,7 +680,6 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@@ -591,18 +706,20 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
if x in validate_function_inputs or validate_has_kwargs:
input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
if r is not True and not isinstance(r, ExecutionBlocker):
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
@@ -613,8 +730,6 @@ def validate_inputs(prompt, item, validated):
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
@@ -780,7 +895,7 @@ class PromptQueue:
completed: bool
messages: List[str]
def task_done(self, item_id, outputs,
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex:
prompt = self.currently_running.pop(item_id)
@@ -793,9 +908,10 @@ class PromptQueue:
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
"outputs": {},
'status': status_dict,
}
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
def get_current_queue(self):


@@ -17,7 +17,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
@@ -44,6 +44,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory):
try:
os.makedirs(input_directory)
@@ -128,12 +132,14 @@ def exists_annotated_filepath(name) -> bool:
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
def get_folder_paths(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
@@ -180,6 +186,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
return None
folders = folder_names_and_paths[folder_name]
@@ -194,6 +201,7 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
@@ -208,6 +216,7 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
@@ -227,6 +236,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
return out
def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
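
A hedged sketch of the legacy-name redirection added above: every folder-name lookup first passes through map_legacy, so older callers that still ask for "unet" transparently read the "diffusion_models" registration, while all other names pass through unchanged.

def map_legacy(folder_name: str) -> str:
    legacy = {"unet": "diffusion_models"}
    return legacy.get(folder_name, folder_name)

assert map_legacy("unet") == "diffusion_models"
assert map_legacy("loras") == "loras"  # non-legacy names are untouched
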

main.py

@@ -6,6 +6,10 @@ import importlib.util
import folder_paths
import time
from comfy.cli_args import args
from app.logger import setup_logger
setup_logger(verbose=args.verbose)
def execute_prestartup_script():
@@ -101,7 +105,7 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@@ -121,7 +125,7 @@ def prompt_worker(q, server):
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id,
e.outputs_ui,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
@@ -242,6 +246,7 @@ if __name__ == "__main__":
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
@@ -261,6 +266,7 @@ if __name__ == "__main__":
call_on_start = startup_server
try:
loop.run_until_complete(server.setup())
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
logging.info("\nStopped server")


@@ -0,0 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename


@@ -0,0 +1,240 @@
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import models_dir
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file:
return existing_file
try:
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir))
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
last_update_time = time.time()
with open(file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False
# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False
# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False
# Ensure the filename isn't too long
if len(filename) > 255:
return False
return True
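
A hedged usage sketch for download_model (not part of the diff). The session handling, callback, and entry point below are illustrative assumptions; only the download_model signature and the DownloadModelStatus fields come from the code above.

import asyncio
import aiohttp
from model_filemanager import download_model, DownloadModelStatus

async def report(relative_path: str, status: DownloadModelStatus):
    # Called at most once per progress_interval plus once on completion/error.
    print(relative_path, status.to_dict())

async def main():
    async with aiohttp.ClientSession() as session:
        result = await download_model(
            lambda url: session.get(url),  # model_download_request
            "v1-5-pruned-emaonly-fp16.safetensors",
            "https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors",
            "checkpoints",                 # model_sub_directory
            report,
            progress_interval=1.0,
        )
        print(result.status, result.message)

# asyncio.run(main())
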

nodes.py

@@ -47,11 +47,18 @@ MAX_RESOLUTION=16384
class CLIPTextEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}
return {
"required": {
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
}
}
RETURN_TYPES = ("CONDITIONING",)
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
FUNCTION = "encode"
CATEGORY = "conditioning"
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
def encode(self, clip, text):
tokens = clip.tokenize(text)
@@ -260,11 +267,18 @@ class ConditioningSetTimestepRange:
class VAEDecode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
return {
"required": {
"samples": ("LATENT", {"tooltip": "The latent to be decoded."}),
"vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."})
}
}
RETURN_TYPES = ("IMAGE",)
OUTPUT_TOOLTIPS = ("The decoded image.",)
FUNCTION = "decode"
CATEGORY = "latent"
DESCRIPTION = "Decodes latent images back into pixel space images."
def decode(self, vae, samples):
return (vae.decode(samples["samples"]), )
@@ -506,12 +520,19 @@ class CheckpointLoader:
class CheckpointLoaderSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders"
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
@@ -582,16 +603,22 @@ class LoraLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP")
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
CATEGORY = "loaders"
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
@@ -638,6 +665,8 @@ class VAELoader:
sd1_taesd_dec = False
sd3_taesd_enc = False
sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
@@ -652,12 +681,18 @@ class VAELoader:
sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
if sd3_taesd_dec and sd3_taesd_enc:
vaes.append("taesd3")
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
return vaes
@staticmethod
@@ -685,6 +720,9 @@ class VAELoader:
elif name == "taesd3":
sd["vae_scale"] = torch.tensor(1.5305)
sd["vae_shift"] = torch.tensor(0.0609)
elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159)
return sd
@classmethod
@@ -697,7 +735,7 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl", "taesd3"]:
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
@@ -817,7 +855,7 @@ class ControlNetApplyAdvanced:
class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
}}
RETURN_TYPES = ("MODEL",)
@@ -826,14 +864,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
dtype = None
model_options = {}
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2
model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path, dtype=dtype)
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
class CLIPLoader:
@@ -1033,13 +1071,19 @@ class EmptyLatentImage:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
return {
"required": {
"width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The width of the latent images in pixels."}),
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8, "tooltip": "The height of the latent images in pixels."}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."})
}
}
RETURN_TYPES = ("LATENT",)
OUTPUT_TOOLTIPS = ("The empty latent image batch.",)
FUNCTION = "generate"
CATEGORY = "latent"
DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
@@ -1359,24 +1403,27 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
class KSampler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
return {
"required": {
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}),
"positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}),
"negative": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to exclude from the image."}),
"latent_image": ("LATENT", {"tooltip": "The latent image to denoise."}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."}),
}
}
RETURN_TYPES = ("LATENT",)
OUTPUT_TOOLTIPS = ("The denoised latent.",)
FUNCTION = "sample"
CATEGORY = "sampling"
DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
@@ -1424,11 +1471,15 @@ class SaveImage:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
return {
"required": {
"images": ("IMAGE", {"tooltip": "The images to save."}),
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"hidden": {
"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
@@ -1436,6 +1487,7 @@ class SaveImage:
OUTPUT_NODE = True
CATEGORY = "image"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
@@ -2077,3 +2129,5 @@ def init_extra_nodes(init_custom_nodes=True):
else:
logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("")
return import_failed
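
A hedged caller-side sketch (not from the diff): with the new return value, startup code can react to custom-node import failures instead of only reading the log. Treat the exact contents of import_failed as whatever init_extra_nodes collected for the failing custom nodes; a non-empty value simply means something did not load.

import logging
import nodes

import_failed = nodes.init_extra_nodes(init_custom_nodes=True)
if import_failed:
    logging.warning("Some custom nodes failed to import: %s", import_failed)
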


@@ -79,7 +79,7 @@
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
"\n",
"# SD1.5\n",
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
"\n",
"# SD2\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",


@@ -1,6 +1,7 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')
execution: mark as execution test (deselect with '-m "not execution"')
testpaths =
tests
tests-unit


@@ -43,7 +43,7 @@ prompt_text = """
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
}
},
"5": {


@@ -41,15 +41,14 @@ def get_images(ws, prompt):
continue #previews are binary data
history = get_history(prompt_id)[prompt_id]
for o in history['outputs']:
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
if 'images' in node_output:
images_output = []
for image in node_output['images']:
image_data = get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
images_output = []
if 'images' in node_output:
for image in node_output['images']:
image_data = get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
return output_images
@@ -85,7 +84,7 @@ prompt_text = """
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
}
},
"5": {


@@ -81,7 +81,7 @@ prompt_text = """
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
}
},
"5": {


@@ -12,7 +12,6 @@ import json
import glob
import struct
import ssl
import hashlib
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@@ -28,7 +27,9 @@ import comfy.model_management
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@@ -40,6 +41,21 @@ async def send_socket_catch_exception(function, message):
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
logging.warning("send error: {}".format(err))
def get_comfyui_version():
comfyui_version = "unknown"
repo_path = os.path.dirname(os.path.realpath(__file__))
try:
import pygit2
repo = pygit2.Repository(repo_path)
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
except Exception:
try:
import subprocess
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
except Exception as e:
logging.warning(f"Failed to get ComfyUI version: {e}")
return comfyui_version.strip()
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
@@ -72,10 +88,12 @@ class PromptServer():
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0
middlewares = [cache_control]
@@ -138,6 +156,14 @@ class PromptServer():
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models/{folder}")
async def get_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = folder_paths.get_filename_list(folder)
return web.json_response(files)
@routes.get("/extensions")
async def get_extensions(request):
files = glob.glob(os.path.join(
@@ -389,16 +415,20 @@ class PromptServer():
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
async def system_stats(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
"comfyui_version": get_comfyui_version(),
"python_version": sys.version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
"pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
},
"devices": [
{
@@ -422,6 +452,7 @@ class PromptServer():
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
@@ -437,6 +468,14 @@ class PromptServer():
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
if getattr(obj_class, "DEPRECATED", False):
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
return info
@routes.get("/object_info")
@@ -559,9 +598,42 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
payload = status.to_dict()
payload['download_path'] = filename
await self.send_json("download_progress", payload)
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.
# This is very useful for the frontend dev server, which needs to forward
@@ -680,6 +752,9 @@ class PromptServer():
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
self.address = address
self.port = port
if verbose:
logging.info("Starting server\n")
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))

tests-ui/.gitignore (vendored)

@@ -1 +0,0 @@
node_modules


@@ -1,9 +0,0 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file beforehand to ensure it's all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});


@@ -1,4 +0,0 @@
{
"presets": ["@babel/preset-env"],
"plugins": ["babel-plugin-transform-import-meta"]
}


@@ -1,14 +0,0 @@
module.exports = async function () {
global.ResizeObserver = class ResizeObserver {
observe() {}
unobserve() {}
disconnect() {}
};
const { nop } = require("./utils/nopProxy");
global.enableWebGLCanvas = nop;
HTMLCanvasElement.prototype.getContext = nop;
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
};


@@ -1,11 +0,0 @@
/** @type {import('jest').Config} */
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true,
resetModules: true,
testTimeout: 10000
};
module.exports = config;

File diff suppressed because it is too large.


@@ -1,31 +0,0 @@
{
"name": "comfui-tests",
"version": "1.0.0",
"description": "UI tests",
"main": "index.js",
"scripts": {
"test": "jest",
"test:generate": "node setup.js"
},
"repository": {
"type": "git",
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
},
"keywords": [
"comfyui",
"test"
],
"author": "comfyanonymous",
"license": "GPL-3.0",
"bugs": {
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
},
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
"devDependencies": {
"@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5",
"babel-plugin-transform-import-meta": "^2.2.1",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0"
}
}


@@ -1,88 +0,0 @@
const { spawn } = require("child_process");
const { resolve } = require("path");
const { existsSync, mkdirSync, writeFileSync } = require("fs");
const http = require("http");
async function setup() {
// Wait up to 30s for it to start
let success = false;
let child;
for (let i = 0; i < 30; i++) {
try {
await new Promise((res, rej) => {
http
.get("http://127.0.0.1:8188/object_info", (resp) => {
let data = "";
resp.on("data", (chunk) => {
data += chunk;
});
resp.on("end", () => {
// Modify the response data to add some checkpoints
const objectInfo = JSON.parse(data);
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
data = JSON.stringify(objectInfo, undefined, "\t");
const outDir = resolve("./data");
if (!existsSync(outDir)) {
mkdirSync(outDir);
}
const outPath = resolve(outDir, "object_info.json");
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
writeFileSync(outPath, data, {
encoding: "utf8",
});
res();
});
})
.on("error", rej);
});
success = true;
break;
} catch (error) {
console.log(i + "/30", error);
if (i === 0) {
// Start the server on first iteration if it fails to connect
console.log("Starting ComfyUI server...");
let python = resolve("../../python_embeded/python.exe");
let args;
let cwd;
if (existsSync(python)) {
args = ["-s", "ComfyUI/main.py"];
cwd = "../..";
} else {
python = "python";
args = ["main.py"];
cwd = "..";
}
args.push("--cpu");
console.log(python, ...args);
child = spawn(python, args, { cwd });
child.on("error", (err) => {
console.log(`Server error (${err})`);
i = 30;
});
child.on("exit", (code) => {
if (!success) {
console.log(`Server exited (${code})`);
i = 30;
}
});
}
await new Promise((r) => {
setTimeout(r, 1000);
});
}
}
child?.kill();
if (!success) {
throw new Error("Waiting for server failed...");
}
}
setup();


@@ -1,196 +0,0 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("extensions", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
it("calls each extension hook", async () => {
const mockExtension = {
name: "TestExtension",
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn(),
};
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.init).toHaveBeenCalledWith(app);
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
expect(defs).toHaveProperty("KSampler");
expect(defs).toHaveProperty("LoadImage");
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < 10; i++) {
// It should be sent the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
expect(nodeClass.name).toBe("ComfyNode");
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
expect(nodeDef.name).toBe(nodeNames[i]);
expect(nodeDef).toHaveProperty("input");
expect(nodeDef).toHaveProperty("output");
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledWith(app);
// Ensure hooks are called in the correct order
const callOrder = [
"init",
"addCustomNodeDefs",
"getCustomWidgets",
"beforeRegisterNodeDef",
"registerCustomNodes",
"beforeConfigureGraph",
"nodeCreated",
"loadedGraphNode",
"afterConfigureGraph",
"setup",
];
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]];
const fn2 = mockExtension[callOrder[i]];
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
}
graph.clear();
// Ensure adding a new node calls the correct callback
ez.LoadImage();
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
// Reload the graph to ensure correct hooks are fired
await graph.reload();
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
}, 15000);
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe("TestNode");
expect(inputName).toBe("test_input");
expect(inputData[0]).toBe("CUSTOMWIDGET");
expect(inputData[1]?.hello).toBe("world");
expect(app).toStrictEqual(app);
return {
widget: node.addWidget("button", inputName, "hello", () => {}),
};
});
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: "TestExtension",
addCustomNodeDefs: (nodeDefs) => {
nodeDefs["TestNode"] = {
output: [],
output_name: [],
output_is_list: [],
name: "TestNode",
display_name: "TestNode",
category: "Test",
input: {
required: {
test_input: ["CUSTOMWIDGET", { hello: "world" }],
},
},
};
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock,
};
}),
};
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
graph.clear();
expect(widgetMock).toBeCalledTimes(0);
const node = ez.TestNode();
expect(widgetMock).toBeCalledTimes(1);
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0);
expect(node.widgets.length).toBe(1);
const w = node.widgets[0].widget;
expect(w.name).toBe("test_input");
expect(w.type).toBe("button");
});
});

File diff suppressed because it is too large.


@@ -1,295 +0,0 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("users", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
function expectNoUserScreen() {
// Ensure login isnt visible
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection["style"].display).toBe("none");
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
}
describe("multi-user", () => {
function mockAddStylesheet() {
const utils = require("../../web/scripts/utils");
utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve());
}
async function waitForUserScreenShow() {
mockAddStylesheet();
// Wait for "show" to be called
const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection");
let resolve, reject;
const fn = UserSelectionScreen.prototype.show;
const p = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => {
const res = fn(...args);
await new Promise(process.nextTick); // wait for promises to resolve
resolve();
return res;
});
// @ts-ignore
setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500);
await p;
await new Promise(process.nextTick); // wait for promises to resolve
}
async function testUserScreen(onShown, users) {
if (!users) {
users = {};
}
const starting = start({
resetEnv: true,
userConfig: { storage: "server", users },
});
// Ensure no current user
expect(localStorage["Comfy.userId"]).toBeFalsy();
expect(localStorage["Comfy.userName"]).toBeFalsy();
await waitForUserScreenShow();
const selection = document.querySelectorAll("#comfy-user-selection")?.[0];
expect(selection).toBeTruthy();
// Ensure login is visible
expect(window.getComputedStyle(selection)?.display).not.toBe("none");
// Ensure menu is hidden
const menu = document.querySelectorAll(".comfy-menu")?.[0];
expect(window.getComputedStyle(menu)?.display).toBe("none");
const isCreate = await onShown(selection);
// Submit form
selection.querySelectorAll("form")[0].submit();
await new Promise(process.nextTick); // wait for promises to resolve
// Wait for start
const s = await starting;
// Ensure login is removed
expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0);
expect(window.getComputedStyle(menu)?.display).not.toBe("none");
// Ensure settings + templates are saved
const { api } = require("../../web/scripts/api");
expect(api.createUser).toHaveBeenCalledTimes(+isCreate);
expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate);
expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate);
if (isCreate) {
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(s.app.isNewUserSession).toBeTruthy();
} else {
expect(s.app.isNewUserSession).toBeFalsy();
}
return { users, selection, ...s };
}
it("allows user creation if no users", async () => {
const { users } = await testUserScreen((selection) => {
// Ensure we have no users flag added
expect(selection.classList.contains("no-users")).toBeTruthy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User";
return true;
});
expect(users).toStrictEqual({
"Test User!": "Test User",
});
expect(localStorage["Comfy.userId"]).toBe("Test User!");
expect(localStorage["Comfy.userName"]).toBe("Test User");
});
it("allows user creation if no current user but other users", async () => {
const users = {
"Test User 2!": "Test User 2",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Enter a username
const input = selection.getElementsByTagName("input")[0];
input.focus();
input.value = "Test User 3";
return true;
}, users);
expect(users).toStrictEqual({
"Test User 2!": "Test User 2",
"Test User 3!": "Test User 3",
});
expect(localStorage["Comfy.userId"]).toBe("Test User 3!");
expect(localStorage["Comfy.userName"]).toBe("Test User 3");
});
it("allows user selection if no current user but other users", async () => {
const users = {
"A!": "A",
"B!": "B",
"C!": "C",
};
await testUserScreen((selection) => {
expect(selection.classList.contains("no-users")).toBeFalsy();
// Check user list
const select = selection.getElementsByTagName("select")[0];
const options = select.getElementsByTagName("option");
expect(
[...options]
.filter((o) => !o.disabled)
.reduce((p, n) => {
p[n.getAttribute("value")] = n.textContent;
return p;
}, {})
).toStrictEqual(users);
// Select an option
select.focus();
select.value = options[2].value;
return false;
}, users);
expect(users).toStrictEqual(users);
expect(localStorage["Comfy.userId"]).toBe("B!");
expect(localStorage["Comfy.userName"]).toBe("B");
});
it("doesnt show user screen if current user", async () => {
const starting = start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
await new Promise(process.nextTick); // wait for promises to resolve
expectNoUserScreen();
await starting;
});
it("allows user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: {
storage: "server",
users: {
"User!": "User",
},
},
localStorage: {
"Comfy.userId": "User!",
"Comfy.userName": "User",
},
});
// cant actually test switching user easily but can check the setting is present
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy();
});
});
describe("single-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledTimes(1);
expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false });
expect(app.isNewUserSession).toBeTruthy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
describe("browser-user", () => {
it("doesnt show user creation if no default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: false, storage: "browser" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt show user creation if default user", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "server" },
});
expectNoUserScreen();
// It should store the settings
const { api } = require("../../web/scripts/api");
expect(api.storeSettings).toHaveBeenCalledTimes(0);
expect(api.storeUserData).toHaveBeenCalledTimes(0);
expect(app.isNewUserSession).toBeFalsy();
});
it("doesnt allow user switching", async () => {
const { app } = await start({
resetEnv: true,
userConfig: { migrated: true, storage: "browser" },
});
expectNoUserScreen();
expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy();
});
});
});


@@ -1,557 +0,0 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const {
start,
makeNodeDef,
checkBeforeAndAfterReload,
assertNotNullOrUndefined,
createDefaultWorkflow,
} = require("../utils");
const lg = require("../utils/litegraph");
/**
* @typedef { import("../utils/ezgraph") } Ez
* @typedef { ReturnType<Ez["Ez"]["graph"]>["ez"] } EzNodeFactory
*/
/**
* @param { EzNodeFactory } ez
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param { number } controlWidgetCount
* @returns
*/
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input);
await checkBeforeAndAfterReload(graph, async () => {
primitive = graph.find(primitive);
let { connections } = primitive.outputs[0];
expect(connections).toHaveLength(1);
expect(connections[0].targetNode.id).toBe(input.node.node.id);
// Ensure widget is correct type
const valueWidget = primitive.widgets.value;
expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
if (widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
}
// Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
});
return primitive;
}
describe("widget inputs", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
[
{ name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" },
{
name: "customtext",
type: "STRING",
opt: { multiline: true },
},
{ name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({
mockNodeDefs: makeNodeDef("TestNode", { [c.name]: [c.type, c.opt ?? {}] }),
});
// Create test node and convert to input
const n = ez.TestNode();
const w = n.widgets[c.name];
w.convertToInput();
expect(w.isConvertedToInput).toBeTruthy();
const input = w.getConvertedInput();
expect(input).toBeTruthy();
// @ts-ignore : input is valid here
await connectPrimitiveAndReload(ez, graph, input, c.widget ?? c.name, c.control);
});
});
test("converted widget works after reload", async () => {
const { ez, graph } = await start();
let n = ez.CheckpointLoaderSimple();
const inputCount = n.inputs.length;
// Convert ckpt name to an input
n.widgets.ckpt_name.convertToInput();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
expect(n.inputs.ckpt_name).toBeTruthy();
expect(n.inputs.length).toEqual(inputCount + 1);
// Convert back to widget and ensure input is removed
n.widgets.ckpt_name.convertToWidget();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(n.inputs.ckpt_name).toBeFalsy();
expect(n.inputs.length).toEqual(inputCount);
// Convert again and reload the graph to ensure it maintains state
n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect();
let { connections } = primitive.outputs[0];
expect(connections).toHaveLength(0);
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
({ connections } = primitive.outputs[0]);
expect(connections).toHaveLength(1);
expect(connections[0].targetNode.id).toBe(n.node.id);
// Convert back to widget and ensure input is removed
n.widgets.ckpt_name.convertToWidget();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(n.inputs.ckpt_name).toBeFalsy();
expect(n.inputs.length).toEqual(inputCount);
});
test("converted widget works on clone", async () => {
const { graph, ez } = await start();
let n = ez.CheckpointLoaderSimple();
// Convert the widget to an input
n.widgets.ckpt_name.convertToInput();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
// Clone the node
n.menu["Clone"].call();
expect(graph.nodes).toHaveLength(2);
const clone = graph.nodes[1];
expect(clone.id).not.toEqual(n.id);
// Ensure the clone has an input
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
expect(clone.inputs.ckpt_name).toBeTruthy();
// Ensure primitive connects to both nodes
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
primitive.outputs[0].connectTo(clone.inputs.ckpt_name);
expect(primitive.outputs[0].connections).toHaveLength(2);
// Convert back to widget and ensure input is removed
clone.widgets.ckpt_name.convertToWidget();
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(clone.inputs.ckpt_name).toBeFalsy();
});
test("shows missing node error on custom node with converted input", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 4,
nodes: [
{
id: 1,
type: "TestNode",
pos: [41.87329101561909, 389.7381480823742],
size: { 0: 220, 1: 374 },
flags: {},
order: 1,
mode: 0,
inputs: [{ name: "test", type: "FLOAT", link: 4, widget: { name: "test" }, slot_index: 0 }],
outputs: [],
properties: { "Node name for S&R": "TestNode" },
widgets_values: [1],
},
{
id: 3,
type: "PrimitiveNode",
pos: [-312, 433],
size: { 0: 210, 1: 82 },
flags: {},
order: 0,
mode: 0,
outputs: [{ links: [4], widget: { name: "test" } }],
title: "test",
properties: {},
},
],
links: [[4, 3, 0, 1, 6, "FLOAT"]],
groups: [],
config: {},
extra: {},
version: 0.4,
});
expect(dialogShow).toBeCalledTimes(1);
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
});
test("defaultInput widgets can be converted back to inputs", async () => {
const { graph, ez } = await start({
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { defaultInput: true }] }),
});
// Create test node and ensure it starts as an input
let n = ez.TestNode();
let w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
let input = w.getConvertedInput();
expect(input).toBeTruthy();
// Ensure it can be converted to
w.convertToWidget();
expect(w.isConvertedToInput).toBeFalsy();
expect(n.inputs.length).toEqual(0);
// and from
w.convertToInput();
expect(w.isConvertedToInput).toBeTruthy();
input = w.getConvertedInput();
// Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
// Convert back to widget and ensure it is still a widget after reload
w.convertToWidget();
await graph.reload();
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets[0].isConvertedToInput).toBeFalsy();
expect(n.inputs.length).toEqual(0);
});
test("forceInput widgets can not be converted back to inputs", async () => {
const { graph, ez } = await start({
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { forceInput: true }] }),
});
// Create test node and ensure it starts as an input
let n = ez.TestNode();
let w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
const input = w.getConvertedInput();
expect(input).toBeTruthy();
// Convert to widget should error
expect(() => w.convertToWidget()).toThrow();
// Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
}
});
test("primitive can connect to matching combos on converted widgets", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
...makeNodeDef("TestNode2", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
p.outputs[0].connectTo(n2.inputs[0]);
expect(p.outputs[0].connections).toHaveLength(2);
const valueWidget = p.widgets.value;
expect(valueWidget.widget.type).toBe("combo");
expect(valueWidget.widget.options.values).toEqual(["A", "B", "C"]);
});
test("primitive can not connect to non matching combos on converted widgets", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
expect(() => p.outputs[0].connectTo(n2.inputs[0])).toThrow();
expect(p.outputs[0].connections).toHaveLength(1);
});
test("combo output can not connect to non matching combos list input", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const n3 = ez.TestNode3();
n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
});
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for (let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesnt apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
describe("reroutes", () => {
async function checkOutput(graph, values) {
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
4: {
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
class_type: "EmptyLatentImage",
},
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: values?.scheduler ?? "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
7: {
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
class_type: "SaveImage",
},
});
}
async function waitForWidget(node) {
// widgets are created slightly after the graph is ready
// hard to find an exact hook to get these so just wait for them to be ready
for (let i = 0; i < 10; i++) {
await new Promise((r) => setTimeout(r, 10));
if (node.widgets?.value) {
return;
}
}
}
it("can connect primitive via a reroute path to a widget input", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.sampler.widgets.scheduler.convertToInput();
nodes.save.widgets.filename_prefix.convertToInput();
let widthReroute = ez.Reroute();
let schedulerReroute = ez.Reroute();
let fileReroute = ez.Reroute();
let widthNext = widthReroute;
let schedulerNext = schedulerReroute;
let fileNext = fileReroute;
for (let i = 0; i < 5; i++) {
let next = ez.Reroute();
widthNext.outputs[0].connectTo(next.inputs[0]);
widthNext = next;
next = ez.Reroute();
schedulerNext.outputs[0].connectTo(next.inputs[0]);
schedulerNext = next;
next = ez.Reroute();
fileNext.outputs[0].connectTo(next.inputs[0]);
fileNext = next;
}
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
let widthPrimitive = ez.PrimitiveNode();
let schedulerPrimitive = ez.PrimitiveNode();
let filePrimitive = ez.PrimitiveNode();
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
expect(widthPrimitive.widgets.value.value).toBe(512);
widthPrimitive.widgets.value.value = 1024;
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
schedulerPrimitive.widgets.value.value = "simple";
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
filePrimitive.widgets.value.value = "ComfyTest";
await checkBeforeAndAfterReload(graph, async () => {
widthPrimitive = graph.find(widthPrimitive);
schedulerPrimitive = graph.find(schedulerPrimitive);
filePrimitive = graph.find(filePrimitive);
await waitForWidget(filePrimitive);
expect(widthPrimitive.widgets.length).toBe(2);
expect(schedulerPrimitive.widgets.length).toBe(3);
expect(filePrimitive.widgets.length).toBe(1);
await checkOutput(graph, {
width: 1024,
scheduler: "simple",
filename_prefix: "ComfyTest",
});
});
});
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.empty.widgets.height.convertToInput();
nodes.empty.widgets.batch_size.convertToInput();
let reroute = ez.Reroute();
let prevReroute = reroute;
for (let i = 0; i < 5; i++) {
const next = ez.Reroute();
prevReroute.outputs[0].connectTo(next.inputs[0]);
prevReroute = next;
}
const r1 = ez.Reroute(prevReroute.outputs[0]);
const r2 = ez.Reroute(prevReroute.outputs[0]);
const r3 = ez.Reroute(r2.outputs[0]);
const r4 = ez.Reroute(r2.outputs[0]);
r1.outputs[0].connectTo(nodes.empty.inputs.width);
r3.outputs[0].connectTo(nodes.empty.inputs.height);
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(reroute.inputs[0]);
expect(primitive.widgets.value.value).toBe(1);
primitive.widgets.value.value = 64;
await checkBeforeAndAfterReload(graph, async (r) => {
primitive = graph.find(primitive);
await waitForWidget(primitive);
// Ensure widget configs are merged
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
await checkOutput(graph, {
width: 64,
height: 64,
batch_size: 64,
});
});
});
});
});


@@ -1,452 +0,0 @@
// @ts-check
/// <reference path="../../web/types/litegraph.d.ts" />
/**
* @typedef { import("../../web/scripts/app")["app"] } app
* @typedef { import("../../web/types/litegraph") } LG
* @typedef { import("../../web/types/litegraph").IWidget } IWidget
* @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
* @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
* @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
* @typedef { InstanceType<LG["LGraphNode"]> & { widgets?: Array<IWidget> } } LGNode
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record<string, unknown>]) => EzNode } EzNodeFactory
*/
export class EzConnection {
/** @type { app } */
app;
/** @type { InstanceType<LG["LLink"]> } */
link;
get originNode() {
return new EzNode(this.app, this.app.graph.getNodeById(this.link.origin_id));
}
get originOutput() {
return this.originNode.outputs[this.link.origin_slot];
}
get targetNode() {
return new EzNode(this.app, this.app.graph.getNodeById(this.link.target_id));
}
get targetInput() {
return this.targetNode.inputs[this.link.target_slot];
}
/**
* @param { app } app
* @param { InstanceType<LG["LLink"]> } link
*/
constructor(app, link) {
this.app = app;
this.link = link;
}
disconnect() {
this.targetInput.disconnect();
}
}
export class EzSlot {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/**
* @param { EzNode } node
* @param { number } index
*/
constructor(node, index) {
this.node = node;
this.index = index;
}
}
export class EzInput extends EzSlot {
/** @type { INodeInputSlot } */
input;
/**
* @param { EzNode } node
* @param { number } index
* @param { INodeInputSlot } input
*/
constructor(node, index, input) {
super(node, index);
this.input = input;
}
get connection() {
const link = this.node.node.inputs?.[this.index]?.link;
if (link == null) {
return null;
}
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
}
disconnect() {
this.node.node.disconnectInput(this.index);
}
}
export class EzOutput extends EzSlot {
/** @type { INodeOutputSlot } */
output;
/**
* @param { EzNode } node
* @param { number } index
* @param { INodeOutputSlot } output
*/
constructor(node, index, output) {
super(node, index);
this.output = output;
}
get connections() {
return (this.node.node.outputs?.[this.index]?.links ?? []).map(
(l) => new EzConnection(this.node.app, this.node.app.graph.links[l])
);
}
/**
* @param { EzInput } input
*/
connectTo(input) {
if (!input) throw new Error("Invalid input");
/**
* @type { LG["LLink"] | null }
*/
const link = this.node.node.connect(this.index, input.node.node, input.index);
if (!link) {
const inp = input.input;
const inName = inp.name || inp.label || inp.type;
throw new Error(
`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
this.output.name ?? this.output.type
}#${this.index}] failed.`
);
}
return link;
}
}
export class EzNodeMenuItem {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/** @type { ContextMenuItem } */
item;
/**
* @param { EzNode } node
* @param { number } index
* @param { ContextMenuItem } item
*/
constructor(node, index, item) {
this.node = node;
this.index = index;
this.item = item;
}
call(selectNode = true) {
if (!this.item?.callback) throw new Error(`Menu Item ${this.item?.content ?? "[null]"} has no callback.`);
if (selectNode) {
this.node.select();
}
return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
}
}
export class EzWidget {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/** @type { IWidget } */
widget;
/**
* @param { EzNode } node
* @param { number } index
* @param { IWidget } widget
*/
constructor(node, index, widget) {
this.node = node;
this.index = index;
this.widget = widget;
}
get value() {
return this.widget.value;
}
set value(v) {
this.widget.value = v;
this.widget.callback?.call?.(this.widget, v)
}
get isConvertedToInput() {
// @ts-ignore : this type is valid for converted widgets
return this.widget.type === "converted-widget";
}
getConvertedInput() {
if (!this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} is not converted to input.`);
return this.node.inputs.find((inp) => inp.input["widget"]?.name === this.widget.name);
}
convertToWidget() {
if (!this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
var menu = this.node.menu["Convert Input to Widget"].item.submenu.options;
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to widget`);
menu[index].callback.call();
}
convertToInput() {
if (this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
var menu = this.node.menu["Convert Widget to Input"].item.submenu.options;
var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to input`);
menu[index].callback.call();
}
}
export class EzNode {
/** @type { app } */
app;
/** @type { LGNode } */
node;
/**
* @param { app } app
* @param { LGNode } node
*/
constructor(app, node) {
this.app = app;
this.node = node;
}
get id() {
return this.node.id;
}
get inputs() {
return this.#makeLookupArray("inputs", "name", EzInput);
}
get outputs() {
return this.#makeLookupArray("outputs", "name", EzOutput);
}
get widgets() {
return this.#makeLookupArray("widgets", "name", EzWidget);
}
get menu() {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
}
get isRemoved() {
return !this.app.graph.getNodeById(this.id);
}
select(addToSelection = false) {
this.app.canvas.selectNode(this.node, addToSelection);
}
// /**
// * @template { "inputs" | "outputs" } T
// * @param { T } type
// * @returns { Record<string, type extends "inputs" ? EzInput : EzOutput> & (type extends "inputs" ? EzInput [] : EzOutput[]) }
// */
// #getSlotItems(type) {
// // @ts-ignore : these items are correct
// return (this.node[type] ?? []).reduce((p, s, i) => {
// if (s.name in p) {
// throw new Error(`Unable to store input ${s.name} on array as name conflicts.`);
// }
// // @ts-ignore
// p.push((p[s.name] = new (type === "inputs" ? EzInput : EzOutput)(this, i, s)));
// return p;
// }, Object.assign([], { $: this }));
// }
/**
* @template { { new(node: EzNode, index: number, obj: any): any } } T
* @param { "inputs" | "outputs" | "widgets" | (() => Array<unknown>) } nodeProperty
* @param { string } nameProperty
* @param { T } ctor
* @returns { Record<string, InstanceType<T>> & Array<InstanceType<T>> }
*/
#makeLookupArray(nodeProperty, nameProperty, ctor) {
const items = typeof nodeProperty === "function" ? nodeProperty() : this.node[nodeProperty];
// @ts-ignore
return (items ?? []).reduce((p, s, i) => {
if (!s) return p;
const name = s[nameProperty];
const item = new ctor(this, i, s);
// @ts-ignore
p.push(item);
if (name) {
// @ts-ignore
if (name in p) {
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
}
}
// @ts-ignore
p[name] = item;
return p;
}, Object.assign([], { $: this }));
}
}
export class EzGraph {
/** @type { app } */
app;
/**
* @param { app } app
*/
constructor(app) {
this.app = app;
}
get nodes() {
return this.app.graph._nodes.map((n) => new EzNode(this.app, n));
}
clear() {
this.app.graph.clear();
}
arrange() {
this.app.graph.arrange();
}
stringify() {
return JSON.stringify(this.app.graph.serialize(), undefined);
}
/**
* @param { number | LGNode | EzNode } obj
* @returns { EzNode }
*/
find(obj) {
let match;
let id;
if (typeof obj === "number") {
id = obj;
} else {
id = obj.id;
}
match = this.app.graph.getNodeById(id);
if (!match) {
throw new Error(`Unable to find node with ID ${id}.`);
}
return new EzNode(this.app, match);
}
/**
* @returns { Promise<void> }
*/
reload() {
const graph = JSON.parse(JSON.stringify(this.app.graph.serialize()));
return new Promise((r) => {
this.app.graph.clear();
setTimeout(async () => {
await this.app.loadGraphData(graph);
r();
}, 10);
});
}
/**
* @returns { Promise<{
* workflow: {},
* output: Record<string, {
* class_name: string,
* inputs: Record<string, [string, number] | unknown>
* }>}> }
*/
toPrompt() {
// @ts-ignore
return this.app.graphToPrompt();
}
}
export const Ez = {
/**
* Quickly build and interact with a ComfyUI graph
* @example
* const { ez, graph } = Ez.graph(app);
* graph.clear();
* const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
* const [image] = ez.VAEDecode(latent, vae).outputs;
* const saveNode = ez.SaveImage(image);
* console.log(saveNode);
* graph.arrange();
* @param { app } app
* @param { LG["LiteGraph"] } LiteGraph
* @param { LG["LGraphCanvas"] } LGraphCanvas
* @param { boolean } clearGraph
* @returns { { graph: EzGraph, ez: Record<string, EzNodeFactory> } }
*/
graph(app, LiteGraph = window["LiteGraph"], LGraphCanvas = window["LGraphCanvas"], clearGraph = true) {
// Always set the active canvas so things work
LGraphCanvas.active_canvas = app.canvas;
if (clearGraph) {
app.graph.clear();
}
// @ts-ignore : this proxy handles utility methods & node creation
const factory = new Proxy(
{},
{
get(_, p) {
if (typeof p !== "string") throw new Error("Invalid node");
const node = LiteGraph.createNode(p);
if (!node) throw new Error(`Unknown node "${p}"`);
app.graph.add(node);
/**
* @param {Parameters<EzNodeFactory>} args
*/
return function (...args) {
const ezNode = new EzNode(app, node);
const inputs = ezNode.inputs;
let slot = 0;
for (const arg of args) {
if (arg instanceof EzOutput) {
arg.connectTo(inputs[slot++]);
} else {
for (const k in arg) {
ezNode.widgets[k].value = arg[k];
}
}
}
return ezNode;
};
},
}
);
return { graph: new EzGraph(app), ez: factory };
},
};


@@ -1,129 +0,0 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
const lg = require("./litegraph");
const fs = require("fs");
const path = require("path");
const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html"))
/**
*
* @param { Parameters<typeof mockApi>[0] & {
* resetEnv?: boolean,
* preSetup?(app): Promise<void>,
* localStorage?: Record<string, string>
* } } config
* @returns
*/
export async function start(config = {}) {
if(config.resetEnv) {
jest.resetModules();
jest.resetAllMocks();
lg.setup(global);
localStorage.clear();
sessionStorage.clear();
}
Object.assign(localStorage, config.localStorage ?? {});
document.body.innerHTML = html;
mockApi(config);
const { app } = require("../../web/scripts/app");
config.preSetup?.(app);
await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}
/**
* @param { ReturnType<Ez["graph"]>["graph"] } graph
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
*/
export async function checkBeforeAndAfterReload(graph, cb) {
await cb(false);
await graph.reload();
await cb(true);
}
/**
* @param { string } name
* @param { Record<string, string | [string | string[], any]> } input
* @param { (string | string[])[] | Record<string, string | string[]> } output
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
name,
category: "test",
output: [],
output_name: [],
output_is_list: [],
input: {
required: {},
},
};
for (const k in input) {
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
}
if (output instanceof Array) {
output = output.reduce((p, c) => {
p[c] = c;
return p;
}, {});
}
for (const k in output) {
nodeDef.output.push(output[k]);
nodeDef.output_name.push(k);
nodeDef.output_is_list.push(false);
}
return { [name]: nodeDef };
}
/**
* @template { any } T
* @param { T } x
* @returns { x is Exclude<T, null | undefined> }
*/
export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(null);
expect(x).not.toEqual(undefined);
return true;
}
/**
*
* @param { ReturnType<Ez["graph"]>["ez"] } ez
* @param { ReturnType<Ez["graph"]>["graph"] } graph
*/
export function createDefaultWorkflow(ez, graph) {
graph.clear();
const ckpt = ez.CheckpointLoaderSimple();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const empty = ez.EmptyLatentImage();
const sampler = ez.KSampler(
ckpt.outputs.MODEL,
pos.outputs.CONDITIONING,
neg.outputs.CONDITIONING,
empty.outputs.LATENT
);
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode.outputs.IMAGE);
graph.arrange();
return { ckpt, pos, neg, empty, sampler, decode, save };
}
export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
return api.getNodeDefs();
}
export async function getNodeDef(nodeId) {
return (await getNodeDefs())[nodeId];
}


@@ -1,36 +0,0 @@
const fs = require("fs");
const path = require("path");
const { nop } = require("../utils/nopProxy");
function forEachKey(cb) {
for (const k of [
"LiteGraph",
"LGraph",
"LLink",
"LGraphNode",
"LGraphGroup",
"DragAndScale",
"LGraphCanvas",
"ContextMenu",
]) {
cb(k);
}
}
export function setup(ctx) {
const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
const globalTemp = {};
(function (console) {
eval(lg);
}).call(globalTemp, nop);
forEachKey((k) => (ctx[k] = globalTemp[k]));
require(path.resolve("../web/lib/litegraph.extensions.js"));
}
export function teardown(ctx) {
forEachKey((k) => delete ctx[k]);
// Clear document after each run
document.getElementsByTagName("html")[0].innerHTML = "";
}


@@ -1,6 +0,0 @@
export const nop = new Proxy(function () {}, {
get: () => nop,
set: () => true,
apply: () => nop,
construct: () => nop,
});


@@ -1,82 +0,0 @@
require("../../web/scripts/api");
const fs = require("fs");
const path = require("path");
function* walkSync(dir) {
const files = fs.readdirSync(dir, { withFileTypes: true });
for (const file of files) {
if (file.isDirectory()) {
yield* walkSync(path.join(dir, file.name));
} else {
yield path.join(dir, file.name);
}
}
}
/**
* @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
*/
/**
* @param {{
* mockExtensions?: string[],
* mockNodeDefs?: Record<string, ComfyObjectInfo>,
* settings?: Record<string, string>
* userConfig?: {storage: "server" | "browser", users?: Record<string, any>, migrated?: boolean },
* userData?: Record<string, any>
* }} config
*/
export function mockApi(config = {}) {
let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = {
userConfig,
settings: {},
userData: {},
...config,
};
if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
.filter((x) => x.endsWith(".js"))
.map((x) => path.relative(path.resolve("../web"), x));
}
if (!mockNodeDefs) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
}
const events = new EventTarget();
const mockApi = {
addEventListener: events.addEventListener.bind(events),
removeEventListener: events.removeEventListener.bind(events),
dispatchEvent: events.dispatchEvent.bind(events),
getSystemStats: jest.fn(),
getExtensions: jest.fn(() => mockExtensions),
getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
createUser: jest.fn((username) => {
if(username in userConfig.users) {
return { status: 400, json: () => "Duplicate" }
}
userConfig.users[username + "!"] = username;
return { status: 200, json: () => username + "!" }
}),
getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }),
getSettings: jest.fn(() => settings),
storeSettings: jest.fn((v) => Object.assign(settings, v)),
getUserData: jest.fn((f) => {
if (f in userData) {
return { status: 200, json: () => userData[f] };
} else {
return { status: 404 };
}
}),
storeUserData: jest.fn((file, data) => {
userData[file] = data;
}),
listUserData: jest.fn(() => [])
};
jest.mock("../../web/scripts/api", () => ({
get api() {
return mockApi;
},
}));
}


@@ -1,6 +1,7 @@
import argparse
import pytest
from requests.exceptions import HTTPError
from unittest.mock import patch

from app.frontend_management import (
    FrontendManager,
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
    with pytest.raises(HTTPError):
        FrontendManager.init_frontend_unsafe(version_string)


@pytest.fixture
def mock_os_functions():
    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
         patch('app.frontend_management.os.listdir') as mock_listdir, \
         patch('app.frontend_management.os.rmdir') as mock_rmdir:
        mock_listdir.return_value = []  # Simulate empty directory
        yield mock_makedirs, mock_listdir, mock_rmdir


@pytest.fixture
def mock_download():
    with patch('app.frontend_management.download_release_asset_zip') as mock:
        mock.side_effect = Exception("Download failed")  # Simulate download failure
        yield mock


def test_finally_block(mock_os_functions, mock_download, mock_provider):
    # Arrange
    mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
    version_string = 'test-owner/test-repo@1.0.0'

    # Act & Assert
    with pytest.raises(Exception):
        FrontendManager.init_frontend_unsafe(version_string, mock_provider)

    # Assert
    mock_makedirs.assert_called_once()
    mock_download.assert_called_once()
    mock_listdir.assert_called_once()
    mock_rmdir.assert_called_once()
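The four assertions above imply a create-then-clean-up-on-failure pattern inside FrontendManager.init_frontend_unsafe. As a rough sketch only (assumed shape, hypothetical names; not the actual app.frontend_management code):

# Sketch of the cleanup pattern the assertions imply -- not the real implementation.
import os

def init_frontend_sketch(web_root, download):
    os.makedirs(web_root, exist_ok=True)
    try:
        download(web_root)                  # may raise, e.g. when the release download fails
    finally:
        if not os.listdir(web_root):        # only remove a directory we left empty
            os.rmdir(web_root)
    return web_root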

def test_parse_version_string():
    version_string = "owner/repo@1.0.0"

@@ -0,0 +1,321 @@
import pytest
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename


class AsyncIteratorMock:
    """
    A mock class that simulates an asynchronous iterator.
    This is used to mimic the behavior of aiohttp's content iterator.
    """
    def __init__(self, seq):
        # Convert the input sequence into an iterator
        self.iter = iter(seq)

    def __aiter__(self):
        # This method is called when 'async for' is used
        return self

    async def __anext__(self):
        # This method is called for each iteration in an 'async for' loop
        try:
            return next(self.iter)
        except StopIteration:
            # This is the asynchronous equivalent of StopIteration
            raise StopAsyncIteration


class ContentMock:
    """
    A mock class that simulates the content attribute of an aiohttp ClientResponse.
    This class provides the iter_chunked method which returns an async iterator of chunks.
    """
    def __init__(self, chunks):
        # Store the chunks that will be returned by the iterator
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        # This method mimics aiohttp's content.iter_chunked()
        # For simplicity in testing, we ignore chunk_size and just return our predefined chunks
        return AsyncIteratorMock(self.chunks)
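For orientation, here is how the two mocks cooperate: ContentMock stands in for ClientResponse.content, and iter_chunked() hands back an AsyncIteratorMock that an async for loop can drain, which is presumably how download_model consumes a response body. The helper below is an illustration only and is not part of the test file.

# Illustration only -- a hypothetical helper, not part of the suite above.
import asyncio

async def drain(content):
    received = []
    async for chunk in content.iter_chunked(512):  # AsyncIteratorMock supports 'async for'
        received.append(chunk)
    return received

assert asyncio.run(drain(ContentMock([b'a' * 500, b'b' * 300]))) == [b'a' * 500, b'b' * 300]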
@pytest.mark.asyncio
async def test_download_model_success():
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.status = 200
    mock_response.headers = {'Content-Length': '1000'}

    # Create a mock for content that returns an async iterator directly
    chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
    mock_response.content = ContentMock(chunks)

    mock_make_request = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # Mock file operations
    mock_open = MagicMock()
    mock_file = MagicMock()
    mock_open.return_value.__enter__.return_value = mock_file

    time_values = itertools.count(0, 0.1)

    with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
         patch('model_filemanager.check_file_exists', return_value=None), \
         patch('builtins.open', mock_open), \
         patch('time.time', side_effect=time_values):  # Simulate time passing
        result = await download_model(
            mock_make_request,
            'model.sft',
            'http://example.com/model.sft',
            'checkpoints',
            mock_progress_callback
        )

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Successfully downloaded model.sft'
    assert result.status == 'completed'
    assert result.already_existed is False

    # Check progress callback calls
    assert mock_progress_callback.call_count >= 3  # At least start, one progress update, and completion

    # Check initial call
    mock_progress_callback.assert_any_call(
        'checkpoints/model.sft',
        DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
    )

    # Check final call
    mock_progress_callback.assert_any_call(
        'checkpoints/model.sft',
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
    )

    # Verify file writing
    mock_file.write.assert_any_call(b'a' * 500)
    mock_file.write.assert_any_call(b'b' * 300)
    mock_file.write.assert_any_call(b'c' * 200)

    # Verify request was made
    mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
    # Mock dependencies
    mock_response = AsyncMock(spec=ClientResponse)
    mock_response.status = 404  # Simulate a "Not Found" error
    mock_get = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # Mock the create_model_path function
    with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
        # Mock the check_file_exists function to return None (file doesn't exist)
        with patch('model_filemanager.check_file_exists', return_value=None):
            # Call the function
            result = await download_model(
                mock_get,
                'model.safetensors',
                'http://example.com/model.safetensors',
                'mock_directory',
                mock_progress_callback
            )

    # Assert the expected behavior
    assert isinstance(result, DownloadModelStatus)
    assert result.status == 'error'
    assert result.message == 'Failed to download model.safetensors. Status code: 404'
    assert result.already_existed is False

    # Check that progress_callback was called with the correct arguments
    mock_progress_callback.assert_any_call(
        'mock_directory/model.safetensors',
        DownloadModelStatus(
            status=DownloadStatusType.PENDING,
            progress_percentage=0,
            message='Starting download of model.safetensors',
            already_existed=False
        )
    )
    mock_progress_callback.assert_called_with(
        'mock_directory/model.safetensors',
        DownloadModelStatus(
            status=DownloadStatusType.ERROR,
            progress_percentage=0,
            message='Failed to download model.safetensors. Status code: 404',
            already_existed=False
        )
    )

    # Verify that the get method was called with the correct URL
    mock_get.assert_called_once_with('http://example.com/model.safetensors')
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
    mock_make_request = AsyncMock()
    mock_progress_callback = AsyncMock()

    result = await download_model(
        mock_make_request,
        'model.sft',
        'http://example.com/model.sft',
        '../bad_path',
        mock_progress_callback
    )

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Invalid model subdirectory'
    assert result.status == 'error'
    assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
    mock_models_dir = tmp_path / "models"
    monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))

    model_name = "test_model.sft"
    model_directory = "test_dir"

    file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)

    assert file_path == str(mock_models_dir / model_directory / model_name)
    assert relative_path == f"{model_directory}/{model_name}"
    assert os.path.exists(os.path.dirname(file_path))
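The behavior this test pins down can be read as: join the directory and name under the models root, create the directory tree, and return both the absolute path and the relative "<dir>/<name>" path. The sketch below is inferred from the assertions only and is not the actual model_filemanager code.

# Assumed behavior, reverse-engineered from the test above.
import os

def create_model_path_sketch(model_name, model_directory, models_base_dir):
    full_model_dir = os.path.join(models_base_dir, model_directory)
    os.makedirs(full_model_dir, exist_ok=True)           # directory must exist afterwards
    file_path = os.path.join(full_model_dir, model_name)
    relative_path = f"{model_directory}/{model_name}"    # forward slash, regardless of OS
    return file_path, relative_path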
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
    file_path = tmp_path / "existing_model.sft"
    file_path.touch()  # Create an empty file

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")

    assert result is not None
    assert result.status == "completed"
    assert result.message == "existing_model.sft already exists"
    assert result.already_existed is True

    mock_callback.assert_called_once_with(
        "test/existing_model.sft",
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
    )


@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    file_path = tmp_path / "non_existing_model.sft"

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")

    assert result is None
    mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {}  # No Content-Length header
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    with patch('builtins.open', mock_open):
        result = await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=0.1
        )

    assert result.status == "completed"

    # Check that progress was reported even without knowing the total size
    mock_callback.assert_any_call(
        'models/model.sft',
        DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
    )
@pytest.mark.asyncio
async def test_track_download_progress_interval():
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    # Create a mock time function that returns incremental float values
    mock_time = MagicMock()
    mock_time.side_effect = [i * 0.5 for i in range(30)]  # This should be enough for 10 chunks

    with patch('builtins.open', mock_open), \
         patch('time.time', mock_time):
        await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=1.0
        )

    # Print out the actual call count and the arguments of each call for debugging
    print(f"mock_callback was called {mock_callback.call_count} times")
    for i, call in enumerate(mock_callback.call_args_list):
        args, kwargs = call
        print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")

    # Assert that progress was updated at least 3 times (start, at least one interval, and end)
    assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"

    # Verify the first and last calls
    first_call = mock_callback.call_args_list[0]
    assert first_call[0][1].status == "in_progress"
    # Allow for some initial progress, but it should be less than 50%
    assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"

    last_call = mock_callback.call_args_list[-1]
    assert last_call[0][1].status == "completed"
    assert last_call[0][1].progress_percentage == 100
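The interval logic exercised here amounts to time-based throttling of progress callbacks: emit at most one in-progress update per interval, plus a final completed update. A minimal sketch of that pattern, with assumed names and a plain iterable instead of an aiohttp stream (not the actual track_download_progress code):

# Illustrative throttling pattern only; names and structure are assumptions.
import time

async def report_progress_throttled(chunks, total, callback, interval=1.0):
    downloaded = 0
    last_update = time.time()
    for chunk in chunks:
        downloaded += len(chunk)
        now = time.time()
        if now - last_update >= interval:   # only emit a callback once per interval
            await callback(100 * downloaded / total)
            last_update = now
    await callback(100.0)                   # always emit a final "completed" update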
def test_valid_subdirectory():
    assert validate_model_subdirectory("valid-model123") is True


def test_subdirectory_too_long():
    assert validate_model_subdirectory("a" * 51) is False


def test_subdirectory_with_double_dots():
    assert validate_model_subdirectory("model/../unsafe") is False


def test_subdirectory_with_slash():
    assert validate_model_subdirectory("model/unsafe") is False


def test_subdirectory_with_special_characters():
    assert validate_model_subdirectory("model@unsafe") is False


def test_subdirectory_with_underscore_and_dash():
    assert validate_model_subdirectory("valid_model-name") is True


def test_empty_subdirectory():
    assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [
    ("valid_model.safetensors", True),
    ("valid_model.sft", True),
    ("valid model.safetensors", True),  # Test with space
    ("UPPERCASE_MODEL.SAFETENSORS", True),
    ("model_with.multiple.dots.pt", False),
    ("", False),  # Empty string
    ("../../../etc/passwd", False),  # Path traversal attempt
    ("/etc/passwd", False),  # Absolute path
    ("\\windows\\system32\\config\\sam", False),  # Windows path
    (".hidden_file.pt", False),  # Hidden file
    ("invalid<char>.ckpt", False),  # Invalid character
    ("invalid?.ckpt", False),  # Another invalid character
    ("very" * 100 + ".safetensors", False),  # Too long filename
    ("\nmodel_with_newline.pt", False),  # Newline character
    ("model_with_emoji😊.pt", False),  # Emoji in filename
])
def test_validate_filename(filename, expected):
    assert validate_filename(filename) == expected
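These cases constrain validate_filename to: a non-hidden base name made of letters, digits, underscores, dashes, and spaces, exactly one dot, a known model extension, and a sane length cap. The sketch below satisfies every parametrized case but is an assumption, not the real validate_filename (the extension list and length limit in particular are guesses):

# Assumed implementation, inferred from the parametrized cases -- not the real source.
import re

_FILENAME_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9_\- ]*\.(safetensors|sft|pt|ckpt)", re.IGNORECASE)

def validate_filename_sketch(filename: str) -> bool:
    if not filename or len(filename) > 255:   # rejects "" and the 400+ character name
        return False
    return _FILENAME_RE.fullmatch(filename) is not None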

@@ -1 +1,3 @@
pytest>=7.8.0
pytest-aiohttp
pytest-asyncio

@@ -0,0 +1,115 @@
import pytest
from aiohttp import web
from unittest.mock import MagicMock, patch
from api_server.routes.internal.internal_routes import InternalRoutes
from api_server.services.file_service import FileService
from folder_paths import models_dir, user_directory, output_directory


@pytest.fixture
def internal_routes():
    return InternalRoutes()


@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
    async def _get_client():
        app = internal_routes.get_app()
        return await aiohttp_client(app)
    return _get_client


@pytest.mark.asyncio
async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
    mock_file_list = [
        {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
        {"name": "dir1", "path": "dir1", "type": "directory"}
    ]
    internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
    client = await aiohttp_client_factory()

    resp = await client.get('/files?directory=models')
    assert resp.status == 200
    data = await resp.json()
    assert 'files' in data
    assert len(data['files']) == 2
    assert data['files'] == mock_file_list

    # Check other valid directories
    resp = await client.get('/files?directory=user')
    assert resp.status == 200
    resp = await client.get('/files?directory=output')
    assert resp.status == 200


@pytest.mark.asyncio
async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
    internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
    client = await aiohttp_client_factory()

    resp = await client.get('/files?directory=invalid')
    assert resp.status == 400
    data = await resp.json()
    assert 'error' in data
    assert data['error'] == "Invalid directory key"


@pytest.mark.asyncio
async def test_list_files_exception(aiohttp_client_factory, internal_routes):
    internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
    client = await aiohttp_client_factory()

    resp = await client.get('/files?directory=models')
    assert resp.status == 500
    data = await resp.json()
    assert 'error' in data
    assert data['error'] == "Unexpected error"


@pytest.mark.asyncio
async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
    mock_file_list = []
    internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
    client = await aiohttp_client_factory()

    resp = await client.get('/files')
    assert resp.status == 200
    data = await resp.json()
    assert 'files' in data
    assert len(data['files']) == 0


def test_setup_routes(internal_routes):
    internal_routes.setup_routes()
    routes = internal_routes.routes
    assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)


def test_get_app(internal_routes):
    app = internal_routes.get_app()
    assert isinstance(app, web.Application)
    assert internal_routes._app is not None


def test_get_app_reuse(internal_routes):
    app1 = internal_routes.get_app()
    app2 = internal_routes.get_app()
    assert app1 is app2


@pytest.mark.asyncio
async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
    client = await aiohttp_client_factory()
    try:
        resp = await client.get('/files')
        print(f"Response received: status {resp.status}")
    except Exception as e:
        print(f"Exception occurred during GET request: {e}")
        raise

    assert resp.status != 404, "Route /files does not exist"


@pytest.mark.asyncio
async def test_file_service_initialization():
    with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
        # Create a mock instance
        mock_file_service_instance = MagicMock(spec=FileService)
        MockFileService.return_value = mock_file_service_instance
        internal_routes = InternalRoutes()

        # Check if FileService was initialized with the correct parameters
        MockFileService.assert_called_once_with({
            "models": models_dir,
            "user": user_directory,
            "output": output_directory
        })

        # Verify that the file_service attribute of InternalRoutes is set
        assert internal_routes.file_service == mock_file_service_instance
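The status codes asserted above (200 with a "files" payload, 400 on an unknown directory key, 500 on any other failure, and 200 when the query parameter is omitted) imply a handler shaped roughly like the following. This is a sketch under assumptions, including the default directory key; it is not the actual InternalRoutes handler.

# Sketch of the behavior the tests pin down; not the actual internal_routes source.
from aiohttp import web

async def get_files_sketch(request, file_service):
    directory_key = request.query.get('directory', 'models')      # default key is an assumption
    try:
        files = file_service.list_files(directory_key)
        return web.json_response({'files': files})
    except ValueError as e:
        return web.json_response({'error': str(e)}, status=400)   # unknown directory key
    except Exception as e:
        return web.json_response({'error': str(e)}, status=500)   # unexpected failure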

@@ -0,0 +1,54 @@
import pytest
from unittest.mock import MagicMock
from api_server.services.file_service import FileService


@pytest.fixture
def mock_file_system_ops():
    return MagicMock()


@pytest.fixture
def file_service(mock_file_system_ops):
    allowed_directories = {
        "models": "/path/to/models",
        "user": "/path/to/user",
        "output": "/path/to/output"
    }
    return FileService(allowed_directories, file_system_ops=mock_file_system_ops)


def test_list_files_valid_directory(file_service, mock_file_system_ops):
    mock_file_system_ops.walk_directory.return_value = [
        {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
        {"name": "dir1", "path": "dir1", "type": "directory"}
    ]

    result = file_service.list_files("models")

    assert len(result) == 2
    assert result[0]["name"] == "file1.txt"
    assert result[1]["name"] == "dir1"
    mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")


def test_list_files_invalid_directory(file_service):
    # Does not support walking directories outside of the allowed directories
    with pytest.raises(ValueError, match="Invalid directory key"):
        file_service.list_files("invalid_key")


def test_list_files_empty_directory(file_service, mock_file_system_ops):
    mock_file_system_ops.walk_directory.return_value = []

    result = file_service.list_files("models")

    assert len(result) == 0
    mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")


@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
    mock_file_system_ops.walk_directory.return_value = [
        {"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
    ]

    result = file_service.list_files(directory_key)

    assert len(result) == 1
    assert result[0]["name"] == f"file_{directory_key}.txt"
    mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
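These tests treat FileService as a thin whitelist around an injected file-system walker: resolve the directory key against allowed_directories, raise ValueError for anything else, and delegate to walk_directory. A sketch consistent with the tests (assumed shape, not the real api_server source):

# Assumed shape of FileService, inferred from the tests above.
class FileServiceSketch:
    def __init__(self, allowed_directories, file_system_ops):
        self.allowed_directories = allowed_directories
        self.file_system_ops = file_system_ops

    def list_files(self, directory_key):
        if directory_key not in self.allowed_directories:
            raise ValueError("Invalid directory key")
        return self.file_system_ops.walk_directory(self.allowed_directories[directory_key])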

@@ -0,0 +1,42 @@
import pytest
from typing import List
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem, is_file_info


@pytest.fixture
def temp_directory(tmp_path):
    # Create a temporary directory structure
    dir1 = tmp_path / "dir1"
    dir2 = tmp_path / "dir2"
    dir1.mkdir()
    dir2.mkdir()
    (dir1 / "file1.txt").write_text("content1")
    (dir2 / "file2.txt").write_text("content2")
    (tmp_path / "file3.txt").write_text("content3")
    return tmp_path


def test_walk_directory(temp_directory):
    result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
    assert len(result) == 5  # 2 directories and 3 files

    files = [item for item in result if item['type'] == 'file']
    dirs = [item for item in result if item['type'] == 'directory']

    assert len(files) == 3
    assert len(dirs) == 2

    file_names = {file['name'] for file in files}
    assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}

    dir_names = {dir['name'] for dir in dirs}
    assert dir_names == {'dir1', 'dir2'}


def test_walk_directory_empty(tmp_path):
    result = FileSystemOperations.walk_directory(str(tmp_path))
    assert len(result) == 0


def test_walk_directory_file_size(temp_directory):
    result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
    files = [item for item in result if is_file_info(item)]
    for file in files:
        assert file['size'] > 0  # Assuming all files have some content
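A walk_directory that would satisfy these expectations can be built directly on os.walk, returning one dict per directory and one per file with name, relative path, type, and (for files) size. The sketch below is an assumption for illustration; the real FileSystemOperations may differ, for example in how paths are reported.

# Assumed implementation for illustration; not the actual api_server code.
import os

def walk_directory_sketch(directory: str):
    items = []
    for root, dirs, files in os.walk(directory):
        for name in dirs:
            path = os.path.relpath(os.path.join(root, name), directory)
            items.append({"name": name, "path": path, "type": "directory"})
        for name in files:
            full = os.path.join(root, name)
            path = os.path.relpath(full, directory)
            items.append({"name": name, "path": path, "type": "file", "size": os.path.getsize(full)})
    return items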

@@ -0,0 +1,4 @@
# Config for testing nodes
testing:
  custom_nodes: tests/inference/testing_nodes
