1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Compare commits

...

185 Commits

Author SHA1 Message Date
comfyanonymous
30c0c81351 Add a way to patch blocks in SD3. 2024-10-29 00:48:32 -04:00
comfyanonymous
13b0ff8a6f Update SD3 code. 2024-10-28 21:58:52 -04:00
comfyanonymous
c320801187 Remove useless line. 2024-10-28 17:41:12 -04:00
Chenlei Hu
c0b0cfaeec Update web content to release v1.3.21 (#5351)
* Update web content to release v1.3.21

* nit
2024-10-28 14:29:38 -04:00
comfyanonymous
669d9e4c67 Set default shift on mochi to 6.0 2024-10-27 22:21:04 -04:00
comfyanonymous
9ee0a6553a float16 inference is a bit broken on mochi. 2024-10-27 04:56:40 -04:00
comfyanonymous
5cbb01bc2f Basic Genmo Mochi video model support.
To use:
"Load CLIP" node with t5xxl + type mochi
"Load Diffusion Model" node with the mochi dit file.
"Load VAE" with the mochi vae file.

EmptyMochiLatentVideo node for the latent.
euler + linear_quadratic in the KSampler node.
2024-10-26 06:54:00 -04:00
comfyanonymous
c3ffbae067 Make LatentUpscale nodes work on 3d latents. 2024-10-26 01:50:51 -04:00
comfyanonymous
d605677b33 Make euler_ancestral work on flow models (credit: Ashen). 2024-10-25 19:53:44 -04:00
Chenlei Hu
ce759b7db6 Revert download to .tmp in frontend_management (#5369) 2024-10-25 19:26:13 -04:00
comfyanonymous
52810907e2 Add a model merge node for SD3.5 large. 2024-10-24 16:46:21 -04:00
PsychoLogicAu
af8cf79a2d support SimpleTuner lycoris lora for SD3 (#5340) 2024-10-24 01:18:32 -04:00
comfyanonymous
66b0961a46 Fix ControlLora issue with last commit. 2024-10-23 17:02:40 -04:00
comfyanonymous
754597c8a9 Clean up some controlnet code.
Remove self.device which was useless.
2024-10-23 14:19:05 -04:00
comfyanonymous
915fdb5745 Fix lowvram edge case. 2024-10-22 16:34:50 -04:00
contentis
5a8a48931a remove attention abstraction (#5324) 2024-10-22 14:02:38 -04:00
comfyanonymous
8ce2a1052c Optimizations to --fast and scaled fp8. 2024-10-22 02:12:28 -04:00
comfyanonymous
f82314fcfc Fix duplicate sigmas on beta scheduler. 2024-10-21 20:19:45 -04:00
comfyanonymous
0075c6d096 Mixed precision diffusion models with scaled fp8.
This change allows supports for diffusion models where all the linears are
scaled fp8 while the other weights are the original precision.
2024-10-21 18:12:51 -04:00
comfyanonymous
83ca891118 Support scaled fp8 t5xxl model. 2024-10-20 22:27:00 -04:00
comfyanonymous
f9f9faface Fixed model merging issue with scaled fp8. 2024-10-20 06:24:31 -04:00
comfyanonymous
471cd3eace fp8 casting is fast on GPUs that support fp8 compute. 2024-10-20 00:54:47 -04:00
comfyanonymous
a68bbafddb Support diffusion models with scaled fp8 weights. 2024-10-19 23:47:42 -04:00
comfyanonymous
73e3a9e676 Clamp output when rounding weight to prevent Nan. 2024-10-19 19:07:10 -04:00
comfyanonymous
518c0dc2fe Add tooltips to LoraSave node. 2024-10-18 06:01:09 -04:00
comfyanonymous
ce0542e10b Add a note that python 3.13 is not yet supported to the README. 2024-10-17 19:27:37 -04:00
comfyanonymous
8473019d40 Pytorch can be shipped with numpy 2 now. 2024-10-17 19:15:17 -04:00
Xiaodong Xie
89f15894dd Ignore more network related errors during websocket communication. (#5269)
Intermittent network issues during websocket communication should not crash ComfyUi process.

Co-authored-by: Xiaodong Xie <xie.xiaodong@frever.com>
2024-10-17 18:31:45 -04:00
comfyanonymous
67158994a4 Use the lowvram cast_to function for everything. 2024-10-17 17:25:56 -04:00
comfyanonymous
7390ff3b1e Add missing import. 2024-10-16 14:58:30 -04:00
comfyanonymous
0bedfb26af Revert "Fix Transformers FutureWarning (#5140)"
This reverts commit 95b7cf9bbe.
2024-10-16 12:36:19 -04:00
comfyanonymous
f71cfd2687 Add an experimental node to sharpen latents.
Can be used with LatentApplyOperationCFG for interesting results.
2024-10-16 05:25:31 -04:00
Alex "mcmonkey" Goodwin
c695c4af7f Frontend Manager: avoid redundant gh calls for static versions (#5152)
* Frontend Manager: avoid redundant gh calls for static versions

* actually, removing old tmpdir isn't needed

I tested - downloader code handles this case well already
(also rmdir was wrong func anyway, needed shutil.rmtree if it had content)

* add code comment
2024-10-16 03:35:37 -04:00
comfyanonymous
0dbba9f751 Add some latent operation nodes.
This is a port of the ModelSamplerTonemapNoiseTest from the experiments
repo.

To replicate that node use LatentOperationTonemapReinhard and
LatentApplyOperationCFG together.
2024-10-15 15:00:36 -04:00
comfyanonymous
f584758271 Cleanup some useless lines. 2024-10-14 21:02:39 -04:00
svdc
95b7cf9bbe Fix Transformers FutureWarning (#5140)
* Update sd1_clip.py

Fix Transformers FutureWarning

* Update sd1_clip.py

Fix comment
2024-10-14 20:12:20 -04:00
comfyanonymous
191a0d56b4 Switch default packaging workflows to python 3.12 2024-10-13 06:59:31 -04:00
comfyanonymous
3c60ecd7a8 Fix fp8 ops staying enabled. 2024-10-12 14:10:13 -04:00
comfyanonymous
7ae6626723 Remove useless argument. 2024-10-12 07:16:21 -04:00
comfyanonymous
6632365e16 model_options consistency between functions.
weight_dtype -> dtype
2024-10-11 20:51:19 -04:00
Kadir Nar
ad07796777 🐛 Add device to variable c (#5210) 2024-10-11 20:37:50 -04:00
comfyanonymous
1b80895285 Make clip loader nodes support loading sd3 t5xxl in lower precision.
Add attention mask support in the SD3 text encoder code.
2024-10-10 15:06:15 -04:00
Dr.Lt.Data
5f9d5a244b Hotfix for the div zero occurrence when memory_used_encode is 0 (#5121)
https://github.com/comfyanonymous/ComfyUI/issues/5069#issuecomment-2382656368
2024-10-09 23:34:34 -04:00
Chenlei Hu
14eba07acd Update web content to release v1.3.11 (#5189)
* Update web content to release v1.3.11

* nit
2024-10-09 22:37:04 -04:00
Jonathan Avila
4b2f0d9413 Increase maximum macOS version to 15.0.1 when forcing upcast attention (#5191) 2024-10-09 22:21:41 -04:00
Yoland Yan
25eac1d780 Change runner label for the new runners (#5197) 2024-10-09 20:08:57 -04:00
comfyanonymous
e38c94228b Add a weight_dtype fp8_e4m3fn_fast to the Diffusion Model Loader node.
This is used to load weights in fp8 and use fp8 matrix multiplication.
2024-10-09 19:43:17 -04:00
comfyanonymous
203942c8b2 Fix flux doras with diffusers keys. 2024-10-08 19:03:40 -04:00
Brendan Hoar
3c72c89a52 Update folder_paths.py - try/catch for special file_name values (#5187)
Somehow managed to drop a file called "nul" into a windows checkpoints subdirectory. This caused all sorts of havoc with many nodes that needed the list of checkpoints.
2024-10-08 15:04:32 -04:00
Chenlei Hu
614377abd6 Update web content to release v1.2.64 (#5124) 2024-10-07 17:15:29 -04:00
comfyanonymous
8dfa0cc552 Make SD3 fast previews a little better. 2024-10-07 09:19:59 -04:00
comfyanonymous
e5ecdfdd2d Make fast previews for SDXL a little better by adding a bias. 2024-10-06 19:27:04 -04:00
comfyanonymous
7d29fbf74b Slightly improve the fast previews for flux by adding a bias. 2024-10-06 17:55:46 -04:00
Lex
2c641e64ad IS_CHANGED should be a classmethod (#5159) 2024-10-06 05:47:51 -04:00
comfyanonymous
7d2467e830 Some minor cleanups. 2024-10-05 13:22:39 -04:00
comfyanonymous
6f021d8aa0 Let --verbose have an argument for the log level. 2024-10-04 10:05:34 -04:00
comfyanonymous
d854ed0bcf Allow using SD3 type te output on flux model. 2024-10-03 09:44:54 -04:00
comfyanonymous
abcd006b8c Allow more permutations of clip/t5 in dual clip loader. 2024-10-03 09:26:11 -04:00
comfyanonymous
d985d1d7dc CLIP Loader node now supports clip_l and clip_g only for SD3. 2024-10-02 04:25:17 -04:00
comfyanonymous
d1cdf51e1b Refactor some of the TE detection code. 2024-10-01 07:08:41 -04:00
comfyanonymous
b4626ab93e Add simpletuner lycoris format for SD unet. 2024-09-30 06:03:27 -04:00
comfyanonymous
a9e459c2a4 Use torch.nn.functional.linear in RGB preview code.
Add an optional bias to the latent RGB preview code.
2024-09-29 11:27:49 -04:00
comfyanonymous
3bb4dec720 Fix issue with loras, lowvram and --fast fp8. 2024-09-28 14:42:32 -04:00
City
8733191563 Flux torch.compile fix (#5082) 2024-09-27 22:07:51 -04:00
comfyanonymous
83b01f960a Add backend option to TorchCompileModel.
If you want to use the cudagraphs backend you need to: --disable-cuda-malloc

If you get other backends working feel free to make a PR to add them.
2024-09-27 02:12:37 -04:00
comfyanonymous
d72e871cfa Add a note that the experimental model downloader api will be removed. 2024-09-26 03:17:52 -04:00
comfyanonymous
037c3159b6 Move some nodes out of _for_testing. 2024-09-25 08:41:22 -04:00
comfyanonymous
bdd4a22a2e Fix flux TE not loading t5 embeddings. 2024-09-24 22:57:22 -04:00
comfyanonymous
fdf37566ef Add batch size to EmptyLatentAudio. 2024-09-24 04:32:55 -04:00
Alex "mcmonkey" Goodwin
08c8968482 Internal download API: Add proper validated directory input (#4981)
* add internal /folder_paths route

returns a json maps of folder paths

* (minor) format download_models.py

* initial folder path input on download api

* actually, require folder_path and clean up some code

* partial tests update

* fix & logging

* also download to a tmp file not the live file

to avoid compounding errors from network failure

* update tests again

* test tweaks

* workaround the first tests blocker

* fix file handling in tests

* rewrite test for create_model_path

* minor doc fix

* avoid 'mock_directory'

use temp dir to avoid accidental fs pollution from tests
2024-09-24 03:50:45 -04:00
chaObserv
479a427a48 Add dpmpp_2m_cfg_pp (#4992) 2024-09-24 02:42:56 -04:00
comfyanonymous
3a0eeee320 Make --listen listen on both ipv4 and ipv6 at the same time by default. 2024-09-23 04:38:19 -04:00
comfyanonymous
447da7ea86 Support listening on multiple addresses. 2024-09-23 04:36:59 -04:00
comfyanonymous
9c41bc8d10 Remove useless line. 2024-09-23 02:32:29 -04:00
Robin Huang
6ad0ddbae4 Run unit tests on Windows/MacOS as well. (#5018)
* Run unit tests on Windows as well.

* Test on mac.

* Continue running on error.

* Compared normalized paths to work cross platform.

* Only test common set of mimetypes across operating systems.
2024-09-22 05:01:39 -04:00
RandomGitUser321
a55142f904 Add ws.close() to the websocket examples (#5020)
* add ws.close() to websocket examples

* add and explain ws.close() in websocket examples
2024-09-22 04:59:10 -04:00
comfyanonymous
5718ef69bb Add total and free ram to /system_stats. 2024-09-22 03:42:11 -04:00
RandomGitUser321
13ecf10a92 Added to the websockets_api_example.py to show how to decode latent previews from the binary stream (#5016)
* Update websockets_api_example.py

* even more simplfied
2024-09-22 02:30:44 -04:00
comfyanonymous
7a415f47a9 Add an optional VAE input to the ControlNetApplyAdvanced node.
Deprecate the other controlnet nodes.
2024-09-22 01:24:52 -04:00
Chenlei Hu
89fa2fca24 Update web content to release v1.2.60 (#5017)
* Update web content to release v1.2.60

* Remove dist.zip
2024-09-21 23:28:54 -04:00
comfyanonymous
364b69e931 Make SD3 empty latent image zeros.
This shouldn't change anything. The reason it was not zeros is because it
did matter in early versions of the code.
2024-09-21 09:13:10 -04:00
comfyanonymous
dc96a1ae19 Load controlnet in fp8 if weights are in fp8. 2024-09-21 04:50:12 -04:00
comfyanonymous
2d810b081e Add load_controlnet_state_dict function. 2024-09-21 01:51:51 -04:00
comfyanonymous
9f7e9f0547 Add an error message when a controlnet needs a VAE but none is given. 2024-09-21 01:33:18 -04:00
comfyanonymous
a355f38ecc Make the SD3 controlnet node the default one. 2024-09-21 01:32:46 -04:00
huchenlei
38c69080c7 Add docstring 2024-09-20 03:16:23 -04:00
comfyanonymous
70a708d726 Fix model merging issue. 2024-09-20 02:31:44 -04:00
yoinked
e7d4782736 add laplace scheduler [2407.03297] (#4990)
* add laplace scheduler [2407.03297]

* should be here instead lol

* better settings
2024-09-19 23:23:09 -04:00
Alex "mcmonkey" Goodwin
3326bdfd4e add internal /folder_paths route (#4980)
returns a json maps of folder paths
2024-09-19 09:52:55 -04:00
Alex "mcmonkey" Goodwin
68bb885d22 add 'is_default' to model paths config (#4979)
* add 'is_default' to model paths config

including impl and doc in example file

* update weirdly overspecific test expectations

* oh there's two

* sigh
2024-09-19 08:59:55 -04:00
comfyanonymous
ad66f7c7d8 Add model_options to load_controlnet function. 2024-09-19 08:23:35 -04:00
Simon Lui
de8e8e3b0d Fix xpu Pytorch nightly build from calling optimize which doesn't exist. (#4978) 2024-09-19 05:11:42 -04:00
Alex "mcmonkey" Goodwin
a1e71cfad1 very simple strong-cache on model list (#4969)
* very simple strong-cache on model list

* store the cache after validation too

* only cache object_info for now

* use a 'with' context
2024-09-19 04:40:14 -04:00
comfyanonymous
0bfc7cc998 Create the temp directory on ComfyUI startup instead. 2024-09-18 09:55:57 -04:00
Tom
7183fd1665 Add route to list model types (#4846)
* Add list models route

* Better readable model types list
2024-09-17 04:22:05 -04:00
Alex "mcmonkey" Goodwin
254838f23c add simple error check to model loading (#4950) 2024-09-17 03:57:17 -04:00
pharmapsychotic
0b7dfa986d Improve tiling calculations to reduce number of tiles that need to be processed. (#4944) 2024-09-17 03:51:10 -04:00
comfyanonymous
d514bb38ee Add some option to model_options for the text encoder.
load_device, offload_device and the initial_device can now be set.
2024-09-17 03:49:54 -04:00
comfyanonymous
0849c80e2a get_key_patches now works without unloading the model. 2024-09-17 01:57:59 -04:00
comfyanonymous
56e8f5e4fd VAEDecodeAudio now does some normalization on the audio. 2024-09-16 00:30:36 -04:00
comfyanonymous
e813abbb2c Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
2024-09-15 07:59:38 -04:00
JettHu
5e68a4ce67 Reduce repeated calls of INPUT_TYPES in cache (#4922) 2024-09-15 01:03:09 -04:00
comfyanonymous
ca08597670 Make the inpaint controlnet node work with non inpaint ones. 2024-09-14 09:17:13 -04:00
comfyanonymous
f48e390032 Support AliMama SD3 and Flux inpaint controlnets.
Use the ControlNetInpaintingAliMamaApply node.
2024-09-14 09:05:16 -04:00
Chenlei Hu
369a6dd2c4 Remove empty spaces in user_manager.py (#4917) 2024-09-13 23:30:44 -04:00
comfyanonymous
b3ce8fb9fd Revert "Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871)"
This reverts commit f6b7194f64.
2024-09-13 23:24:47 -04:00
comfyanonymous
cf80d28689 Support loading controlnets with different input. 2024-09-13 09:54:37 -04:00
Acly
6fb44c4b7c Make adding links/nodes to ExecutionList non-recursive (#4886)
Graphs with 300+ chained nodes run into maximum recursion depth error (limit is 1000 in CPython)
2024-09-13 08:25:11 -04:00
Chenlei Hu
d2247c1e61 Normalize path returned by /userdata to always use / as separator (#4906) 2024-09-13 03:45:31 -04:00
Chenlei Hu
cb12ad7049 Add full_info flag in /userdata endpoint to list out file size and last modified timestamp (#4905)
* Add full_info flag in /userdata endpoint to list out file size and last modified timestamp

* nit
2024-09-13 02:40:59 -04:00
JettHu
f6b7194f64 Reduce repeated calls of get_immediate_node_signature for ancestors in cache (#4871) 2024-09-12 23:02:52 -04:00
comfyanonymous
7c6eb4fb29 Set some nodes as DEPRECATED. 2024-09-12 20:27:07 -04:00
Robin Huang
b962db9952 Add cli arg to override user directory (#4856)
* Override user directory.

* Use overridden user directory.

* Remove prints.

* Remove references to global user_files.

* Remove unused replace_folder function.

* Remove newline.

* Remove global during get_user_directory.

* Add validation.
2024-09-12 08:10:27 -04:00
comfyanonymous
d0b7ab88ba Add a simple experimental TorchCompileModel node.
It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
2024-09-12 05:24:25 -04:00
Yoland Yan
405b529545 Minor: update tests-unit README.md (#4896) 2024-09-12 04:53:08 -04:00
comfyanonymous
9d720187f1 types -> comfy_types to fix import issue. 2024-09-12 03:57:46 -04:00
Robin Huang
d247bc5a9c Expand variables in base_path for extra_config_paths.yaml. (#4893)
* Expand variables in base_path for extra_config_paths.yaml.

* Fix comments.
2024-09-12 01:52:06 -04:00
comfyanonymous
9f4daca9d9 Doesn't really make sense for cfg_pp sampler to call regular one. 2024-09-11 02:51:36 -04:00
yoinked
b5d0f2a908 Add CFG++ to DPM++ 2S Ancestral (#3871)
* Update sampling.py

* Update samplers.py

* my bad

* "fix" the sampler

* Update samplers.py

* i named it wrong

* minor sampling improvements

mainly using a dynamic rho value (hey this sounds a lot like smea!!!)

* revert rho change

rho? r? its just 1/2
2024-09-11 02:49:44 -04:00
bymyself
e760bf5c40 Add content-type filter method to folder_paths (#4054)
* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
2024-09-11 02:00:07 -04:00
comfyanonymous
36c83cdbba Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people
who use reverse proxies.
2024-09-11 01:06:37 -04:00
Yoland Yan
81778a7feb [🗻 Mount Fuji Commit] Add unit tests for folder path utilities (#4869)
All past 30 min of comtts are done on the top of Mt Fuji
By Comfy, Robin, and Yoland
All other comfy org members died on the way

Introduced unit tests to verify the correctness of various folder path
utility functions such as `get_directory_by_type`, `annotated_filepath`,
and `recursive_search` among others. These tests cover scenarios
including directory retrieval, filepath annotation, recursive file
searches, and filtering files by extensions, enhancing the robustness
and reliability of the codebase.
2024-09-10 00:44:49 -04:00
comfyanonymous
bc94662b31 Cleanup. 2024-09-10 00:43:37 -04:00
Robin Huang
9fa8faa44a Expand user directory for basepath in extra_models_paths.yaml (#4857)
* Expand user path.

* Add test.

* Add unit test for expanding base path.

* Simplify unit test.

* Remove comment.

* Remove comment.

* Checkpoints.

* Refactor.
2024-09-10 00:33:44 -04:00
comfyanonymous
9a7444e39f Add diffusion_models to the extra_model_paths.yaml.example 2024-09-10 00:21:33 -04:00
comfyanonymous
54fca4a218 If host does not contain a port only compare the hostnames. 2024-09-09 16:28:23 -04:00
Chenlei Hu
cd4955367e Add back CI action for tests-ui (#4859) 2024-09-09 04:32:55 -04:00
david02871
8354203d95 Add .venv to gitignore (#4756) 2024-09-09 04:31:18 -04:00
comfyanonymous
e0b41243b4 Fix issue where sometimes origin doesn't contain the port. 2024-09-09 03:18:17 -04:00
Alex "mcmonkey" Goodwin
619263d4a6 allow current timestamp in save image prefix (#4030) 2024-09-09 02:55:51 -04:00
comfyanonymous
e3b0402bb7 Ignore origin domain when it's empty. 2024-09-09 01:04:56 -04:00
Darion
967867d48c fix: url decode filename from API (#4801) 2024-09-08 21:02:32 -04:00
comfyanonymous
cbaac71bf5 Fix issue with last commit. 2024-09-08 19:35:23 -04:00
comfyanonymous
3ab3516e46 By default only accept requests where origin header matches the host.
Browsers are dumb and let any website do requests to localhost this should
prevent this without breaking things. CORS prevents the javascript from
reading the response but they can still write it.

At the moment this is only enabled when the --enable-cors-header argument
is not used.
2024-09-08 18:17:29 -04:00
comfyanonymous
9c5fca75f4 Fix lora issue. 2024-09-08 10:10:47 -04:00
guill
a5da4d0b3e Fix error with ExecutionBlocker and OUTPUT_IS_LIST (#4836)
This change resolves an error when a node with OUTPUT_IS_LIST=(True,)
receives an ExecutionBlocker. I've also added a unit test for this case.
2024-09-08 09:48:47 -04:00
comfyanonymous
32a60a7bac Support onetrainer text encoder Flux lora. 2024-09-08 09:31:41 -04:00
Jim Winkens
bb52934ba4 Fix import issue (#4815) 2024-09-07 05:28:32 -04:00
comfyanonymous
8aabd7c8c0 SaveLora node can now save "full diff" lora format.
This isn't actually a lora format and is saving the full diff of the
weights in a format that can be used in the lora loader nodes.
2024-09-07 03:21:02 -04:00
comfyanonymous
a09b29ca11 Add an option to the SaveLora node to store the bias diff. 2024-09-07 03:03:30 -04:00
comfyanonymous
9bfee68773 LoraSave node now supports generating text encoder loras.
text_encoder_diff should be connected to a CLIPMergeSubtract node.

model_diff and text_encoder_diff are optional inputs so you can create
model only loras, text encoder only loras or a lora that contains both.
2024-09-07 02:30:12 -04:00
comfyanonymous
ea77750759 Support a generic Comfy format for text encoder loras.
This is a format with keys like:
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.lora_up.weight

Instead of waiting for me to add support for specific lora formats you can
convert your text encoder loras to this format instead.

If you want to see an example save a text encoder lora with the SaveLora
node with the commit right after this one.
2024-09-07 02:20:39 -04:00
comfyanonymous
c27ebeb1c2 Fix onnx export not working on flux. 2024-09-06 03:21:52 -04:00
guill
0c7c98a965 Nodes using UNIQUE_ID as input are NOT_IDEMPOTENT (#4793)
As suggested by @ltdrdata, we can automatically consider nodes that take
the UNIQUE_ID hidden input to be NOT_IDEMPOTENT.
2024-09-05 19:33:02 -04:00
comfyanonymous
dc2eb75b85 Update stable release workflow to latest pytorch with cuda 12.4. 2024-09-05 19:21:52 -04:00
Chenlei Hu
fa34efe3bd Update frontend to v1.2.47 (#4798)
* Update web content to release v1.2.47

* Update shortcut list
2024-09-05 18:56:01 -04:00
comfyanonymous
5cbaa9e07c Mistoline flux controlnet support. 2024-09-05 00:05:17 -04:00
comfyanonymous
c7427375ee Prioritize freeing partially offloaded models first. 2024-09-04 19:47:32 -04:00
comfyanonymous
22d1241a50 Add an experimental LoraSave node to extract model loras.
The model_diff input should be connected to the output of a
ModelMergeSubtract node.
2024-09-04 16:38:38 -04:00
Jedrzej Kosinski
f04229b84d Add emb_patch support to UNetModel forward (#4779) 2024-09-04 14:35:15 -04:00
Silver
f067ad15d1 Make live preview size a configurable launch argument (#4649)
* Make live preview size a configurable launch argument

* Remove import from testing phase

* Update cli_args.py
2024-09-03 19:16:38 -04:00
comfyanonymous
483004dd1d Support newer glora format. 2024-09-03 17:02:19 -04:00
comfyanonymous
00a5d08103 Lower fp8 lora memory usage. 2024-09-03 01:25:05 -04:00
comfyanonymous
d043997d30 Flux onetrainer lora. 2024-09-02 08:22:15 -04:00
Alex "mcmonkey" Goodwin
f1c2301697 fix typo in stale-issues (#4735) 2024-09-01 17:44:49 -04:00
comfyanonymous
8d31a6632f Speed up inference on nvidia 10 series on Linux. 2024-09-01 17:29:31 -04:00
comfyanonymous
b643eae08b Make minimum_inference_memory() depend on --reserve-vram 2024-09-01 01:18:34 -04:00
comfyanonymous
baa6b4dc36 Update manual install instructions. 2024-08-31 04:37:23 -04:00
Alex "mcmonkey" Goodwin
d4aeefc297 add github action to automatically handle stale user support issues (#4683)
* add github action to automatically handle stale user support issues

* improve stale message

* remove token part
2024-08-31 01:57:18 -04:00
comfyanonymous
587e7ca654 Remove github buttons. 2024-08-31 01:53:10 -04:00
Chenlei Hu
c90459eba0 Update ComfyUI_frontend to 1.2.40 (#4691)
* Update ComfyUI_frontend to 1.2.40

* Add files
2024-08-30 19:32:10 -04:00
Vedat Baday
04278afb10 feat: return import_failed from init_extra_nodes function (#4694) 2024-08-30 19:26:47 -04:00
comfyanonymous
935ae153e1 Cleanup. 2024-08-30 12:53:59 -04:00
Chenlei Hu
e91662e784 Get logs endpoint & system_stats additions (#4690)
* Add route for getting output logs

* Include ComfyUI version

* Move to own function

* Changed to memory logger

* Unify logger setup logic

* Fix get version git fallback

---------

Co-authored-by: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
2024-08-30 12:46:37 -04:00
comfyanonymous
63fafaef45 Fix potential issue with hydit controlnets. 2024-08-30 04:58:41 -04:00
Alex "mcmonkey" Goodwin
ec28cd9136 swap legacy sdv15 link (#4682)
* swap legacy sdv15 link

* swap v15 ckpt examples to safetensors

* link the fp16 copy of the model by default
2024-08-29 19:48:48 -04:00
comfyanonymous
6eb5d64522 Fix glora lowvram issue. 2024-08-29 19:07:23 -04:00
comfyanonymous
10a79e9898 Implement model part of flux union controlnet. 2024-08-29 18:41:22 -04:00
comfyanonymous
ea3f39bd69 InstantX depth flux controlnet. 2024-08-29 02:14:19 -04:00
comfyanonymous
b33cd61070 InstantX canny controlnet. 2024-08-28 19:02:50 -04:00
Dr.Lt.Data
34eda0f853 fix: remove redundant useless loop (#4656)
fix: potential error of undefined variable

https://github.com/comfyanonymous/ComfyUI/discussions/4650
2024-08-28 17:46:30 -04:00
comfyanonymous
d31e226650 Unify RMSNorm code. 2024-08-28 16:56:38 -04:00
comfyanonymous
b79fd7d92c ComfyUI supports more than just stable diffusion. 2024-08-28 16:12:24 -04:00
comfyanonymous
38c22e631a Fix case where model was not properly unloaded in merging workflows. 2024-08-27 19:03:51 -04:00
Chenlei Hu
6bbdcd28ae Support weight padding on diff weight patch (#4576) 2024-08-27 13:55:37 -04:00
comfyanonymous
ab130001a8 Do RMSNorm in native type. 2024-08-27 02:41:56 -04:00
Chenlei Hu
ca4b8f30e0 Cleanup empty dir if frontend zip download failed (#4574) 2024-08-27 02:07:25 -04:00
Robin Huang
70b84058c1 Add relative file path to the progress report. (#4621) 2024-08-27 02:06:12 -04:00
comfyanonymous
2ca8f6e23d Make the stochastic fp8 rounding reproducible. 2024-08-26 15:12:06 -04:00
comfyanonymous
7985ff88b9 Use less memory in float8 lora patching by doing calculations in fp16. 2024-08-26 14:45:58 -04:00
comfyanonymous
c6812947e9 Fix potential memory leak. 2024-08-26 02:07:32 -04:00
comfyanonymous
9230f65823 Fix some controlnets OOMing when loading. 2024-08-25 05:54:29 -04:00
guill
6ab1e6fd4a [Bug #4529] Fix graph partial validation failure (#4588)
Currently, if a graph partially fails validation (i.e. some outputs are
valid while others have links from missing nodes), the execution loop
could get an exception resulting in server lockup.

This isn't actually possible to reproduce via the default UI, but is a
potential issue for people using the API to construct invalid graphs.
2024-08-24 15:34:58 -04:00
comfyanonymous
07dcbc3a3e Clarify how to use high quality previews. 2024-08-24 02:31:03 -04:00
comfyanonymous
8ae23d8e80 Fix onnx export. 2024-08-23 17:52:47 -04:00
173 changed files with 114124 additions and 83585 deletions

View File

@@ -14,7 +14,7 @@ run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
RECOMMENDED WAY TO UPDATE:

View File

@@ -23,7 +23,7 @@ jobs:
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:

View File

@@ -12,17 +12,17 @@ on:
description: 'CUDA version'
required: true
type: string
default: "121"
default: "124"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "9"
default: "7"
jobs:

21
.github/workflows/stale-issues.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: 'Close stale issues'
on:
schedule:
# Run daily at 430 am PT
- cron: '30 11 * * *'
permissions:
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
days-before-stale: 30
days-before-close: 7
stale-issue-label: 'Stale'
only-labels: 'User Support'
exempt-all-assignees: true
exempt-all-milestones: true

View File

@@ -32,7 +32,7 @@ jobs:
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
@@ -55,7 +55,7 @@ jobs:
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, win]
runner_label: [self-hosted, Windows]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:

30
.github/workflows/test-unit.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: Unit Tests
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit

View File

@@ -12,7 +12,7 @@ on:
description: 'extra dependencies'
required: false
type: string
default: "\"numpy<2\""
default: ""
cu:
description: 'cuda version'
required: true
@@ -23,13 +23,13 @@ on:
description: 'python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "7"
# push:
# branches:
# - master

View File

@@ -13,13 +13,13 @@ on:
description: 'python minor version'
required: true
type: string
default: "11"
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "9"
default: "7"
# push:
# branches:
# - master

1
.gitignore vendored
View File

@@ -12,6 +12,7 @@ extra_model_paths.yaml
.vscode/
.idea/
venv/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/

View File

@@ -1,7 +1,7 @@
<div align="center">
# ComfyUI
**The most powerful and modular stable diffusion GUI and backend.**
**The most powerful and modular diffusion model GUI and backend.**
[![Website][website-shield]][website-url]
@@ -94,6 +94,8 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
@@ -125,6 +127,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
## Manual Install (Windows, Linux)
Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
Git clone this repo.
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@@ -135,17 +139,17 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2```
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
This is the command to install pytorch nightly instead which might have performance improvements:
@@ -230,7 +234,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`

View File

@@ -1,7 +1,8 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
import app.logger
class InternalRoutes:
'''
@@ -31,6 +32,16 @@ class InternalRoutes:
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self):
if self._app is None:

View File

@@ -8,7 +8,7 @@ import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict
from typing import TypedDict, Optional
import requests
from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
Returns:
str: The path to the initialized frontend.
@@ -150,7 +151,16 @@ class FrontendManager:
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
@@ -158,15 +168,21 @@ class FrontendManager:
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
try:
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
if not os.listdir(web_root):
os.rmdir(web_root)
return web_root
@classmethod

31
app/logger.py Normal file
View File

@@ -0,0 +1,31 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
def get_logs():
return "\n".join([formatter.format(x) for x in logs])
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)
# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)

View File

@@ -5,17 +5,17 @@ import uuid
import glob
import shutil
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
from folder_paths import user_directory
import folder_paths
from .app_settings import AppSettings
default_user = "default"
users_file = os.path.join(user_directory, "users.json")
class UserManager():
def __init__(self):
global user_directory
user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
@@ -25,14 +25,17 @@ class UserManager():
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(users_file):
with open(users_file) as f:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}
def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
def get_request_user_id(self, request):
user = "default"
if args.multi_user and "comfy-user" in request.headers:
@@ -44,7 +47,7 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
global user_directory
user_directory = folder_paths.get_user_directory()
if type == "userdata":
root_dir = user_directory
@@ -59,6 +62,10 @@ class UserManager():
return None
if file is not None:
# Check if filename is url encoded
if "%" in file:
file = parse.unquote(file)
# prevent leaving /{type}/{user}
path = os.path.abspath(os.path.join(user_root, file))
if os.path.commonpath((user_root, path)) != user_root:
@@ -80,8 +87,7 @@ class UserManager():
self.users[user_id] = name
global users_file
with open(users_file, "w") as f:
with open(self.get_users_file(), "w") as f:
json.dump(self.users, f)
return user_id
@@ -112,25 +118,69 @@ class UserManager():
@routes.get("/userdata")
async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400)
return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403)
return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path):
return web.Response(status=404)
return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path:
results = [[x] + x.split(os.sep) for x in results]
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return web.json_response(results)
@@ -138,14 +188,14 @@ class UserManager():
file = request.match_info.get(param, None)
if not file:
return web.Response(status=400)
path = self.get_request_user_filepath(request, file)
if not path:
return web.Response(status=403)
if check_exists and not os.path.exists(path):
return web.Response(status=404)
return path
@routes.get("/userdata/{file}")
@@ -153,7 +203,7 @@ class UserManager():
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path
return web.FileResponse(path)
@routes.post("/userdata/{file}")
@@ -161,7 +211,7 @@ class UserManager():
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
@@ -170,7 +220,7 @@ class UserManager():
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
return web.json_response(resp)
@@ -181,7 +231,7 @@ class UserManager():
return path
os.remove(path)
return web.Response(status=204)
@routes.post("/userdata/{file}/move/{dest}")
@@ -189,17 +239,17 @@ class UserManager():
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp)

View File

@@ -6,6 +6,7 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
@@ -17,10 +18,13 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
None,
self.patch_size,
self.in_channels,
control_latent_channels,
self.hidden_size,
bias=True,
strict_img_size=False,

View File

@@ -36,7 +36,7 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
@@ -92,6 +92,8 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
@@ -134,7 +136,7 @@ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Dis
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -169,6 +171,8 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
@@ -179,10 +183,3 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
import logging
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

View File

@@ -109,8 +109,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
keys = list(sd.keys())
for k in keys:
if k not in u:
t = sd.pop(k)
del t
sd.pop(k)
return clip
def load(ckpt_path):

View File

@@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet_xlabs
import comfy.ldm.flux.controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -60,7 +60,7 @@ class StrengthType(Enum):
LINEAR_UP = 2
class ControlBase:
def __init__(self, device=None):
def __init__(self):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
@@ -72,20 +72,24 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@@ -100,9 +104,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
@@ -123,6 +127,8 @@ class ControlBase:
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -148,7 +154,7 @@ class ControlBase:
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype:
if output_dtype is not None and x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
@@ -175,8 +181,8 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
super().__init__(device)
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__()
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
@@ -189,6 +195,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -206,7 +213,6 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
@@ -214,6 +220,9 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -221,7 +230,15 @@ class ControlNet(ControlBase):
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -236,7 +253,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -320,8 +337,8 @@ class ControlLoraOps:
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
ControlBase.__init__(self, device)
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
@@ -377,21 +394,28 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd):
def controlnet_config(sd, model_options={}):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -403,26 +427,31 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd):
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL()
@@ -430,22 +459,49 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None
supported_inference_dtypes = None
@@ -500,11 +556,15 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data)
else:
return load_controlnet_mmdit(controlnet_data)
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@@ -516,26 +576,38 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
logging.error("error could not detect control model type.")
return net
if controlnet_config is None:
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = comfy.utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = comfy.model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = comfy.model_management.unet_dtype()
else:
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -569,22 +641,32 @@ def load_controlnet(ckpt_path, model=None):
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
global_average_pooling = model_options.get("global_average_pooling", False)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
super().__init__()
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
@@ -632,7 +714,7 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -1,7 +1,17 @@
import torch
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype):
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
@@ -9,44 +19,35 @@ def manual_stochastic_round_to_float8(x, dtype):
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
# zero_mask = (abs_x == 0)
# subnormal_mask = (exponent == 0) & (abs_x != 0)
normal_mask = ~(exponent == 0)
mantissa_scaled = torch.where(
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_floor = mantissa_scaled.floor()
mantissa = torch.where(
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
(mantissa_floor + 1) / (2**MANTISSA_BITS),
mantissa_floor / (2**MANTISSA_BITS)
)
result = torch.where(
normal_mask,
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
result = torch.where(abs_x == 0, 0, result)
return result.to(dtype=dtype)
inf = torch.finfo(dtype)
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
return sign
def stochastic_rounding(value, dtype):
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
@@ -54,6 +55,13 @@ def stochastic_rounding(value, dtype):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
return manual_stochastic_round_to_float8(value, dtype)
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)

View File

@@ -44,6 +44,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
@@ -153,6 +164,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -170,6 +183,29 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
@@ -1069,7 +1105,6 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
d = to_d(x, sigma_hat, temp[0])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = denoised + d * sigmas[i + 1]
return x
@@ -1096,8 +1131,81 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], temp[0])
# Euler method
dt = sigma_down - sigmas[i]
x = denoised + d * sigma_down
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x

View File

@@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
def process_in(self, latent):
@@ -30,11 +31,13 @@ class SDXL(LatentFormat):
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
[ 0.3651, 0.4232, 0.4341],
[-0.2533, -0.0042, 0.1068],
[ 0.1076, 0.1111, -0.0362],
[-0.3165, -0.2492, -0.2188]
]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
@@ -112,23 +115,24 @@ class SD3(LatentFormat):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0645, 0.0177, 0.1052],
[ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360],
[ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259]
[-0.0922, -0.0175, 0.0749],
[ 0.0311, 0.0633, 0.0954],
[ 0.1994, 0.0927, 0.0458],
[ 0.0856, 0.0339, 0.0902],
[ 0.0587, 0.0272, -0.0496],
[-0.0006, 0.1104, 0.0309],
[ 0.0978, 0.0306, 0.0427],
[-0.0042, 0.1038, 0.1358],
[-0.0194, 0.0020, 0.0669],
[-0.0488, 0.0130, -0.0268],
[ 0.0922, 0.0988, 0.0951],
[-0.0278, 0.0524, -0.0542],
[ 0.0332, 0.0456, 0.0895],
[-0.0069, -0.0030, -0.0810],
[-0.0596, -0.0465, -0.0293],
[-0.1448, -0.1463, -0.1189]
]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent):
@@ -146,23 +150,24 @@ class Flux(SD3):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609],
[ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530],
[ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
[-0.0346, 0.0244, 0.0681],
[ 0.0034, 0.0210, 0.0687],
[ 0.0275, -0.0668, -0.0433],
[-0.0174, 0.0160, 0.0617],
[ 0.0859, 0.0721, 0.0329],
[ 0.0004, 0.0383, 0.0115],
[ 0.0405, 0.0861, 0.0915],
[-0.0236, -0.0185, -0.0259],
[-0.0245, 0.0250, 0.1180],
[ 0.1008, 0.0755, -0.0421],
[-0.0515, 0.0201, 0.0011],
[ 0.0428, -0.0012, -0.0036],
[ 0.0817, 0.0765, 0.0749],
[-0.1264, -0.0522, -0.1103],
[-0.0280, -0.0881, -0.0499],
[-0.1262, -0.0982, -0.0778]
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):
@@ -170,3 +175,30 @@ class Flux(SD3):
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class Mochi(LatentFormat):
latent_channels = 12
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
0.959253732819592, 0.8244560132752793, 0.917259975397747,
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors = None #TODO
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean

View File

@@ -1,4 +1,5 @@
import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -6,3 +7,21 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
try:
rms_norm_torch = torch.nn.functional.rms_norm
except:
rms_norm_torch = None
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

View File

@@ -0,0 +1,205 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
#modified to support different types of flux controlnets
import torch
import math
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks
if self.mistoline:
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
self.controlnet_blocks.append(control_block())
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
if control_latent_channels is None:
control_latent_channels = self.in_channels
else:
control_latent_channels *= 2 * 2 #patch size
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
if self.mistoline:
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
else:
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
controlnet_double = ()
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
img = torch.cat((txt, img), 1)
controlnet_single = ()
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_double:
out_input += (x,) * repeat
else:
out_input = (controlnet_double * repeat)
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2
if self.latent_input:
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
elif self.mistoline:
hint = hint * 2.0 - 1.0
hint = self.input_cond_block(hint)
else:
hint = hint * 2.0 - 1.0
hint = self.input_hint_block(hint)
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))

View File

@@ -1,104 +0,0 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class ControlNetFlux(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
return {"input": (controlnet_block_res_samples * 10)[:19]}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
hint = hint * 2.0 - 1.0
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)

View File

@@ -6,6 +6,7 @@ from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module):
@@ -63,10 +64,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):

View File

@@ -108,7 +108,7 @@ class Flux(nn.Module):
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -151,8 +151,8 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@@ -0,0 +1,541 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from flash_attn import flash_attn_varlen_qkvpacked_func
from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
AttentionPool,
modulate,
)
import comfy.ldm.common_dit
import comfy.ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_modulated
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
# Apply tanh to gate
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.dim_x = dim_x
self.dim_y = dim_y
self.num_heads = num_heads
self.head_dim = dim_x // num_heads
self.attn_drop = attn_drop
self.update_y = update_y
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
)
# Input layers.
self.qkv_bias = qkv_bias
self.qkv_x = operations.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
# Project text features to match visual features (dim_y -> dim_x)
self.qkv_y = operations.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device, dtype=dtype)
# Query and key normalization for stability.
assert qk_norm
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
self.proj_y = (
operations.Linear(dim_x, dim_y, bias=out_bias, device=device, dtype=dtype)
if update_y
else nn.Identity()
)
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process visual features
# qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
# assert qkv_x.dtype == torch.bfloat16
# qkv_x = all_to_all_collect_tokens(
# qkv_x, self.num_heads
# ) # (3, B, N, local_h, head_dim)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
q = torch.cat([q_x, q_y[:, :crop_y]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :crop_y]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :crop_y]], dim=1).transpose(1, 2)
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
# print("ox", x)
# print("oy", y)
return x, y
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.update_y = update_y
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.mod_x = operations.Linear(hidden_size_x, 4 * hidden_size_x, device=device, dtype=dtype)
if self.update_y:
self.mod_y = operations.Linear(hidden_size_x, 4 * hidden_size_y, device=device, dtype=dtype)
else:
self.mod_y = operations.Linear(hidden_size_x, hidden_size_y, device=device, dtype=dtype)
# Self-attention:
self.attn = AsymmetricAttention(
hidden_size_x,
hidden_size_y,
num_heads=num_heads,
update_y=update_y,
device=device,
dtype=dtype,
operations=operations,
**block_kwargs,
)
# MLP.
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
assert mlp_hidden_dim_x == int(1536 * 8)
self.mlp_x = FeedForward(
in_features=hidden_size_x,
hidden_size=mlp_hidden_dim_x,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
dtype=dtype,
operations=operations,
)
# MLP for text not needed in last block.
if self.update_y:
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
self.mlp_y = FeedForward(
in_features=hidden_size_y,
hidden_size=mlp_hidden_dim_y,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
):
"""Forward pass of a block.
Args:
x: (B, N, dim) tensor of visual tokens
c: (B, dim) tensor of conditioned features
y: (B, L, dim) tensor of text tokens
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
x: (B, N, dim) tensor of visual tokens after block
y: (B, L, dim) tensor of text tokens after block
"""
N = x.size(1)
c = F.silu(c)
mod_x = self.mod_x(c)
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
mod_y = self.mod_y(c)
if self.update_y:
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
else:
scale_msa_y = mod_y
# Self-attention block.
x_attn, y_attn = self.attn(
x,
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
**attn_kwargs,
)
assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y
def ff_block_x(self, x, scale_x, gate_x):
x_mod = modulated_rmsnorm(x, scale_x)
x_res = self.mlp_x(x_mod)
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
return x
def ff_block_y(self, y, scale_y, gate_y):
y_mod = modulated_rmsnorm(y, scale_y)
y_res = self.mlp_y(y_mod)
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
return y
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype
)
self.mod = operations.Linear(hidden_size, 2 * hidden_size, device=device, dtype=dtype)
self.linear = operations.Linear(
hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype
)
def forward(self, x, c):
c = F.silu(c)
shift, scale = self.mod(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class AsymmDiTJoint(nn.Module):
"""
Diffusion model with a Transformer backbone.
Ingests text embeddings instead of a label.
"""
def __init__(
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
use_t5: bool = False,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
learn_sigma=True,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
attend_to_padding: bool = False,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
posenc_preserve_area: bool = False,
rope_theta: float = 10000.0,
image_model=None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
**block_kwargs,
):
super().__init__()
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.attend_to_padding = attend_to_padding
self.use_extended_posenc = use_extended_posenc
self.posenc_preserve_area = posenc_preserve_area
self.use_t5 = use_t5
self.t5_token_length = t5_token_length
self.t5_feat_dim = t5_feat_dim
self.rope_theta = (
rope_theta # Scaling factor for frequency computation for temporal RoPE.
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size_x,
bias=patch_embed_bias,
dtype=dtype,
device=device,
operations=operations
)
# Conditionings
# Timestep
self.t_embedder = TimestepEmbedder(
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
if self.use_t5:
# Caption Pooling (T5)
self.t5_y_embedder = AttentionPool(
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, dtype=dtype, device=device, operations=operations
)
# Dense Embedding Projection (T5)
self.t5_yproj = operations.Linear(
t5_feat_dim, hidden_size_y, bias=True, dtype=dtype, device=device
)
# Initialize pos_frequencies as an empty parameter.
self.pos_frequencies = nn.Parameter(
torch.empty(3, self.num_heads, self.head_dim // 2, dtype=dtype, device=device)
)
assert not self.attend_to_padding
# for depth 48:
# b = 0: AsymmetricJointBlock, update_y=True
# b = 1: AsymmetricJointBlock, update_y=True
# ...
# b = 46: AsymmetricJointBlock, update_y=True
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
blocks = []
for b in range(depth):
# Joint multi-modal block
update_y = b < depth - 1
block = AsymmetricJointBlock(
hidden_size_x,
hidden_size_y,
num_heads,
mlp_ratio_x=mlp_ratio_x,
mlp_ratio_y=mlp_ratio_y,
update_y=update_y,
attend_to_padding=attend_to_padding,
device=device,
dtype=dtype,
operations=operations,
**block_kwargs,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.final_layer = FinalLayer(
hidden_size_x, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
)
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C=12, T, H, W) tensor of visual tokens
Returns:
x: (B, C=3072, N) tensor of visual tokens with positional embedding.
"""
return self.x_embedder(x) # Convert BcTHW to BCN
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
# Visual patch embeddings with positional encoding.
T, H, W = x.shape[-3:]
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
pos = create_position_matrix(
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=comfy.ops.cast_to(self.pos_frequencies, dtype=x.dtype, device=x.device), pos=pos
) # Each are (N, num_heads, dim // 2)
c_t = self.t_embedder(1 - sigma, out_dtype=x.dtype) # (B, D)
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
c = c_t + t5_y_pool
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
return x, c, y_feat, rope_cos, rope_sin
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: List[torch.Tensor],
attention_mask: List[torch.Tensor],
num_tokens=256,
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, **kwargs
):
y_feat = context
y_mask = attention_mask
sigma = timestep
"""Forward pass of DiT.
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""
B, _, T, H, W = x.shape
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat, y_mask
)
del y_mask
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
x = rearrange(
x,
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
T=T,
hp=H // self.patch_size,
wp=W // self.patch_size,
p1=self.patch_size,
p2=self.patch_size,
c=self.out_channels,
)
return -x

View File

@@ -0,0 +1,164 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ldm.common_dit
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
class TimestepEmbedder(nn.Module):
def __init__(
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=bias, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.timestep_scale = timestep_scale
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
freqs.mul_(-math.log(max_period) / half).exp_()
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t, out_dtype):
if self.timestep_scale is not None:
t = t * self.timestep_scale
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
t_emb = self.mlp(t_freq)
return t_emb
class FeedForward(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
super().__init__()
# keep parameter count and computation constant compared to standard FFN
hidden_size = int(2 * hidden_size / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_size = int(ffn_dim_multiplier * hidden_size)
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
self.hidden_dim = hidden_size
self.w1 = operations.Linear(in_features, 2 * hidden_size, bias=False, device=device, dtype=dtype)
self.w2 = operations.Linear(hidden_size, in_features, bias=False, device=device, dtype=dtype)
def forward(self, x):
x, gate = self.w1(x).chunk(2, dim=-1)
x = self.w2(F.silu(x) * gate)
return x
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
self.flatten = flatten
self.dynamic_img_pad = dynamic_img_pad
self.proj = operations.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
device=device,
dtype=dtype,
)
assert norm_layer is None
self.norm = (
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
)
def forward(self, x):
B, _C, T, H, W = x.shape
if not self.dynamic_img_pad:
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
else:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode='circular')
x = self.proj(x)
# Flatten temporal and spatial dimensions.
if not self.flatten:
raise NotImplementedError("Must flatten output.")
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
x = self.norm(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

View File

@@ -0,0 +1,88 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# import functools
import math
import torch
def centers(start: float, stop, num, dtype=None, device=None):
"""linspace through bin centers.
Args:
start (float): Start of the range.
stop (float): End of the range.
num (int): Number of points.
dtype (torch.dtype): Data type of the points.
device (torch.device): Device of the points.
Returns:
centers (Tensor): Centers of the bins. Shape: (num,).
"""
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
return (edges[:-1] + edges[1:]) / 2
# @functools.lru_cache(maxsize=1)
def create_position_matrix(
T: int,
pH: int,
pW: int,
device: torch.device,
dtype: torch.dtype,
*,
target_area: float = 36864,
):
"""
Args:
T: int - Temporal dimension
pH: int - Height dimension after patchify
pW: int - Width dimension after patchify
Returns:
pos: [T * pH * pW, 3] - position matrix
"""
# Create 1D tensors for each dimension
t = torch.arange(T, dtype=dtype)
# Positionally interpolate to area 36864.
# (3072x3072 frame with 16x16 patches = 192x192 latents).
# This automatically scales rope positions when the resolution changes.
# We use a large target area so the model is more sensitive
# to changes in the learned pos_frequencies matrix.
scale = math.sqrt(target_area / (pW * pH))
w = centers(-pW * scale / 2, pW * scale / 2, pW)
h = centers(-pH * scale / 2, pH * scale / 2, pH)
# Use meshgrid to create 3D grids
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
return pos
def compute_mixed_rotation(
freqs: torch.Tensor,
pos: torch.Tensor,
):
"""
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
Args:
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
pos: [N, 3] - position of each token
num_heads: int
Returns:
freqs_cos: [N, num_heads, num_freqs] - cosine components
freqs_sin: [N, num_heads, num_freqs] - sine components
"""
assert freqs.ndim == 3
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
freqs_cos = torch.cos(freqs_sum)
freqs_sin = torch.sin(freqs_sum)
return freqs_cos, freqs_sin

View File

@@ -0,0 +1,34 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
# Based on Llama3 Implementation.
import torch
def apply_rotary_emb_qk_real(
xqk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
Args:
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
Can be either just query or just key, or both stacked along some batch or * dim.
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
Returns:
torch.Tensor: The input tensor with rotary embeddings applied.
"""
# Split the last dimension into even and odd parts
xqk_even = xqk[..., 0::2]
xqk_odd = xqk[..., 1::2]
# Apply rotation
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
# Interleave the results back into the original shape
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
return out

View File

@@ -0,0 +1,102 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
"""
Pool tokens in x using mask.
NOTE: We assume x does not require gradients.
Args:
x: (B, L, D) tensor of tokens.
mask: (B, L) boolean tensor indicating which tokens are not padding.
Returns:
pooled: (B, D) tensor of pooled tokens.
"""
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
output_dim: int = None,
device: Optional[torch.device] = None,
dtype=None,
operations=None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = operations.Linear(embed_dim, 2 * embed_dim, device=device, dtype=dtype)
self.to_q = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.to_out = operations.Linear(embed_dim, output_dim or embed_dim, device=device, dtype=dtype)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x

View File

@@ -0,0 +1,480 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.ops
ops = comfy.ops.disable_weight_init
# import mochi_preview.dit.joint_model.context_parallel as cp
# from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
class GroupNormSpatial(ops.GroupNorm):
"""
GroupNorm applied per-frame.
"""
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
B, C, T, H, W = x.shape
x = rearrange(x, "B C T H W -> (B T) C H W")
# Run group norm in chunks.
output = torch.empty_like(x)
for b in range(0, B * T, chunk_size):
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class PConv3d(ops.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
):
self.causal = causal
self.context_parallel = context_parallel
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
# Compute padding amounts.
context_size = self.kernel_size[0] - 1
if self.causal:
pad_front = context_size
pad_back = 0
else:
pad_front = context_size // 2
pad_back = context_size - pad_front
# Apply padding.
assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
return super().forward(x)
class Conv1x1(ops.Linear):
"""*1x1 Conv implemented with a linear layer."""
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
super().__init__(in_features, out_features, *args, **kwargs)
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, *] or [B, *, C].
Returns:
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
"""
x = x.movedim(1, -1)
x = super().forward(x)
x = x.movedim(-1, 1)
return x
class DepthToSpaceTime(nn.Module):
def __init__(
self,
temporal_expansion: int,
spatial_expansion: int,
):
super().__init__()
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# When printed, this module should show the temporal and spatial expansion factors.
def extra_repr(self):
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
Returns:
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
"""
x = rearrange(
x,
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
st=self.temporal_expansion,
sh=self.spatial_expansion,
sw=self.spatial_expansion,
)
# cp_rank, _ = cp.get_cp_rank_size()
if self.temporal_expansion > 1: # and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated.
assert all(x.shape)
x = x[:, :, self.temporal_expansion - 1 :]
assert all(x.shape)
return x
def norm_fn(
in_channels: int,
affine: bool = True,
):
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
class ResBlock(nn.Module):
"""Residual block that preserves the spatial dimensions."""
def __init__(
self,
channels: int,
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True,
):
super().__init__()
self.channels = channels
assert causal
self.stack = nn.Sequential(
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
),
)
self.attn_block = attn_block if attn_block else nn.Identity()
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
"""
residual = x
x = self.stack(x)
x = x + residual
del residual
return self.attn_block(x)
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
):
super().__init__()
blocks = []
for _ in range(num_res_blocks):
blocks.append(block_fn(in_channels, **block_kwargs))
self.blocks = nn.Sequential(*blocks)
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# Change channels in the final convolution layer.
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
)
self.d2st = DepthToSpaceTime(
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
)
def forward(self, x):
x = self.blocks(x)
x = self.proj(x)
x = self.d2st(x)
return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
assert has_attention is False #NOTE: if this is ever true add back the attention code.
attn_block = None #AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
# Change the channel count in the strided convolution.
# This lets the ResBlock have uniform channel count,
# as in ConvNeXt.
assert in_channels != out_channels
layers.append(
PConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
padding_mode="replicate",
bias=True,
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
num_freqs = (stop - start) // step
assert inputs.ndim == 5
C = inputs.size(1)
# Create Base 2 Fourier features.
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
assert num_freqs == len(freqs)
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
C = inputs.shape[1]
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
# Interleaved repeat of input channels to match w.
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h
return torch.cat(
[
inputs,
torch.sin(h),
torch.cos(h),
],
dim=1,
)
class FourierFeatures(nn.Module):
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
super().__init__()
self.start = start
self.stop = stop
self.step = step
def forward(self, inputs):
"""Add Fourier features to inputs.
Args:
inputs: Input tensor. Shape: [B, C, T, H, W]
Returns:
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
"""
return add_fourier_features(inputs, self.start, self.stop, self.step)
class Decoder(nn.Module):
def __init__(
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
):
super().__init__()
self.input_channels = latent_dim
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.output_nonlinearity = output_nonlinearity
assert nonlinearity == "silu"
assert causal
ch = [mult * base_channels for mult in channel_multipliers]
self.num_up_blocks = len(ch) - 1
assert len(num_res_blocks) == self.num_up_blocks + 2
blocks = []
first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
first_block.append(
block_fn(
ch[-1],
has_attention=has_attention[-1],
causal=causal,
**block_kwargs,
)
)
blocks.append(nn.Sequential(*first_block))
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
upsample_block_fn = CausalUpsampleBlock
for i in range(self.num_up_blocks):
block = upsample_block_fn(
ch[-i - 1],
ch[-i - 2],
num_res_blocks=num_res_blocks[-i - 2],
has_attention=has_attention[-i - 2],
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
causal=causal,
**block_kwargs,
)
blocks.append(block)
assert not output_norm
# Last block. Preserve channel count.
last_block = []
for _ in range(num_res_blocks[0]):
last_block.append(
block_fn(
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
)
)
blocks.append(nn.Sequential(*last_block))
self.blocks = nn.ModuleList(blocks)
self.output_proj = Conv1x1(ch[0], out_channels)
def forward(self, x):
"""Forward pass.
Args:
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
Returns:
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
T + 1 = (t - 1) * 4.
H = h * 16, W = w * 16.
"""
for block in self.blocks:
x = block(x)
if self.output_nonlinearity == "silu":
x = F.silu(x, inplace=not self.training)
else:
assert (
not self.output_nonlinearity
) # StyleGAN3 omits the to-RGB nonlinearity.
return self.output_proj(x).contiguous()
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
self.encoder = None #TODO once the model releases
self.decoder = Decoder(
out_channels=3,
base_channels=128,
channel_multipliers=[1, 2, 4, 6],
temporal_expansions=[1, 2, 3],
spatial_expansions=[2, 2, 2],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
causal=True,
)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)

View File

@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)

View File

@@ -1,11 +1,11 @@
import logging
import math
from typing import Dict, Optional
from typing import Dict, Optional, List
import numpy as np
import torch
import torch.nn as nn
from .. import attention
from ..attention import optimized_attention
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops
@@ -97,7 +97,7 @@ class PatchEmbed(nn.Module):
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
def optimized_attention(qkv, num_heads):
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = self.pre_attention(x)
q, k, v = self.pre_attention(x)
x = optimized_attention(
qkv, num_heads=self.num_heads
q, k, v, heads=self.num_heads
)
x = self.post_attention(x)
return x
@@ -355,29 +353,9 @@ class RMSNorm(torch.nn.Module):
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
class SwiGLUFeedForward(nn.Module):
@@ -437,6 +415,7 @@ class DismantledBlock(nn.Module):
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
x_block_self_attn: bool = False,
dtype=None,
device=None,
operations=None,
@@ -460,6 +439,24 @@ class DismantledBlock(nn.Module):
device=device,
operations=operations
)
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
self.x_block_self_attn = True
self.attn2 = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=False,
qk_norm=qk_norm,
rmsnorm=rmsnorm,
dtype=dtype,
device=device,
operations=operations
)
else:
self.x_block_self_attn = False
if not pre_only:
if not rmsnorm:
self.norm2 = operations.LayerNorm(
@@ -486,7 +483,11 @@ class DismantledBlock(nn.Module):
multiple_of=256,
)
self.scale_mod_only = scale_mod_only
if not scale_mod_only:
if x_block_self_attn:
assert not pre_only
assert not scale_mod_only
n_mods = 9
elif not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
@@ -547,14 +548,64 @@ class DismantledBlock(nn.Module):
)
return x
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert self.x_block_self_attn
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
shift_msa2,
scale_msa2,
gate_msa2,
) = self.adaLN_modulation(c).chunk(9, dim=1)
x_norm = self.norm1(x)
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
return qkv, qkv2, (
x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
gate_msa2,
)
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv,
num_heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
if self.x_block_self_attn:
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
attn, _ = optimized_attention(
qkv[0], qkv[1], qkv[2],
num_heads=self.attn.num_heads,
)
attn2, _ = optimized_attention(
qkv2[0], qkv2[1], qkv2[2],
num_heads=self.attn2.num_heads,
)
return self.post_attention_x(attn, attn2, *intermediates)
else:
qkv, intermediates = self.pre_attention(x, c)
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=self.attn.num_heads,
)
return self.post_attention(attn, *intermediates)
def block_mixing(*args, use_checkpoint=True, **kwargs):
@@ -569,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
def _block_mixing(context, x, context_block, x_block, c):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
x_qkv, x_intermediates = x_block.pre_attention(x, c)
if x_block.x_block_self_attn:
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
else:
x_qkv, x_intermediates = x_block.pre_attention(x, c)
o = []
for t in range(3):
@@ -577,8 +631,8 @@ def _block_mixing(context, x, context_block, x_block, c):
qkv = tuple(o)
attn = optimized_attention(
qkv,
num_heads=x_block.attn.num_heads,
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -590,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c):
else:
context = None
x = x_block.post_attention(x_attn, *x_intermediates)
if x_block.x_block_self_attn:
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
@@ -605,8 +666,13 @@ class JointBlock(nn.Module):
super().__init__()
pre_only = kwargs.pop("pre_only")
qk_norm = kwargs.pop("qk_norm", None)
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args,
pre_only=False,
qk_norm=qk_norm,
x_block_self_attn=x_block_self_attn,
**kwargs)
def forward(self, *args, **kwargs):
return block_mixing(
@@ -662,7 +728,7 @@ class SelfAttentionContext(nn.Module):
def forward(self, x):
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.dim_head)
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
return self.proj(x)
class ContextProcessorBlock(nn.Module):
@@ -721,9 +787,12 @@ class MMDiT(nn.Module):
qk_norm: Optional[str] = None,
qkv_bias: bool = True,
context_processor_layers = None,
x_block_self_attn: bool = False,
x_block_self_attn_layers: Optional[List[int]] = [],
context_size = 4096,
num_blocks = None,
final_layer = True,
skip_blocks = False,
dtype = None, #TODO
device = None,
operations = None,
@@ -738,6 +807,7 @@ class MMDiT(nn.Module):
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
self.x_block_self_attn_layers = x_block_self_attn_layers
# hidden_size = default(hidden_size, 64 * depth)
# num_heads = default(num_heads, hidden_size // 64)
@@ -795,26 +865,28 @@ class MMDiT(nn.Module):
self.pos_embed = None
self.use_checkpoint = use_checkpoint
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=(i == num_blocks - 1) and final_layer,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
dtype=dtype,
device=device,
operations=operations
)
for i in range(num_blocks)
]
)
if not skip_blocks:
self.joint_blocks = nn.ModuleList(
[
JointBlock(
self.hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
pre_only=(i == num_blocks - 1) and final_layer,
rmsnorm=rmsnorm,
scale_mod_only=scale_mod_only,
swiglu=swiglu,
qk_norm=qk_norm,
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(num_blocks)
]
)
if final_layer:
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -877,7 +949,9 @@ class MMDiT(nn.Module):
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if self.register_length > 0:
context = torch.cat(
(
@@ -889,14 +963,25 @@ class MMDiT(nn.Module):
# context is B, L', D
# x is B, L, D
blocks_replace = patches_replace.get("dit", {})
blocks = len(self.joint_blocks)
for i in range(blocks):
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if control is not None:
control_o = control.get("output")
if i < len(control_o):
@@ -914,6 +999,7 @@ class MMDiT(nn.Module):
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
"""
Forward pass of DiT.
@@ -935,7 +1021,7 @@ class MMDiT(nn.Module):
if context is not None:
context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context, control)
x = self.forward_core_with_concat(x, c, context, control, transformer_options)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x[:,:,:hw[-2],:hw[-1]]
@@ -949,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control)
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)

View File

@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
@@ -200,9 +201,13 @@ def load_lora(lora, to_load):
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
clip_g_present = False
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -226,6 +231,7 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
clip_g_present = True
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
@@ -241,10 +247,18 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
t5_index = 1
if clip_g_present:
t5_index += 1
if clip_l_present:
t5_index += 1
if t5_index == 2:
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
t5_index += 1
key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
@@ -280,6 +294,7 @@ def model_lora_keys_unet(model, key_map={}):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
@@ -302,6 +317,10 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
key_map[key_lora] = to
key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@@ -323,14 +342,15 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
@@ -347,6 +367,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
@@ -366,7 +419,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
@@ -375,12 +428,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
v = v[1]
if patch_type == "diff":
w1 = v[0]
diff: torch.Tensor = v[0]
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
@@ -398,7 +457,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -444,7 +503,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -481,28 +540,48 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:

View File

@@ -24,6 +24,7 @@ from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.genmo.joint_model.asymm_models_joint
import comfy.ldm.aura.mmdit
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
@@ -96,7 +97,8 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -244,6 +246,10 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@@ -713,3 +719,18 @@ class Flux(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
if context_processor in state_dict_keys:
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
unet_config["x_block_self_attn_layers"] = []
for key in state_dict_keys:
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
unet_config["x_block_self_attn_layers"].append(int(layer))
return unet_config
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
@@ -145,6 +150,34 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
dit_config = {}
dit_config["image_model"] = "mochi_preview"
dit_config["depth"] = 48
dit_config["patch_size"] = 2
dit_config["num_heads"] = 24
dit_config["hidden_size_x"] = 3072
dit_config["hidden_size_y"] = 1536
dit_config["mlp_ratio_x"] = 4.0
dit_config["mlp_ratio_y"] = 4.0
dit_config["learn_sigma"] = False
dit_config["in_channels"] = 12
dit_config["qk_norm"] = True
dit_config["qkv_bias"] = False
dit_config["out_bias"] = True
dit_config["attn_drop"] = 0.0
dit_config["patch_embed_bias"] = True
dit_config["posenc_preserve_area"] = True
dit_config["timestep_mlp_bias"] = True
dit_config["attend_to_padding"] = False
dit_config["timestep_scale"] = 1000.0
dit_config["use_t5"] = True
dit_config["t5_feat_dim"] = 4096
dit_config["t5_token_length"] = 256
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -286,9 +319,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
else:
return model_config
model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
if scaled_fp8_weight is not None:
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
return model_config
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", #ldm/sgm models

View File

@@ -45,6 +45,7 @@ cpu_state = CPUState.GPU
total_vram = 0
xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
@@ -144,7 +145,7 @@ total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch.version.__version__))
logging.info("pytorch version: {}".format(torch_version))
except:
pass
@@ -325,7 +326,7 @@ class LoadedModel:
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
@@ -369,12 +370,11 @@ def offloaded_memory(loaded_models, device):
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
if any(platform.win32_ver()):
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 #Windows is higher because of the shared vram issue
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -383,6 +383,9 @@ if args.reserve_vram is not None:
def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
for i in range(len(current_loaded_models)):
@@ -405,6 +408,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
@@ -421,7 +426,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
@@ -621,6 +626,8 @@ def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -640,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
pass
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
return fp8_dtype
free_model_memory = maximum_vram_for_weights(device)
if model_params * 2 > free_model_memory:
return fp8_dtype
@@ -833,27 +843,21 @@ def force_channels_last():
#TODO
return False
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
device_supports_cast = True
elif tensor.dtype == torch.bfloat16:
if hasattr(device, 'type') and device.type.startswith("cuda"):
device_supports_cast = True
elif is_intel_xpu():
device_supports_cast = True
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
non_blocking = device_should_use_non_blocking(device)
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
@@ -892,7 +896,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
upcast = True
except:
pass
@@ -999,7 +1003,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
return True
if WINDOWS or manual_cast:
return True
else:
return False #weird linux behavior where fp32 is faster
if manual_cast:
free_model_memory = maximum_vram_for_weights(device)
@@ -1055,6 +1062,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
@@ -1062,6 +1072,14 @@ def supports_fp8_compute(device=None):
return False
if props.minor < 9:
return False
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
return False
if WINDOWS:
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
return False
return True
def soft_empty_cache(force=False):

View File

@@ -28,8 +28,20 @@ import comfy.utils
import comfy.float
import comfy.model_management
import comfy.lora
from comfy.types import UnetWrapperFunction
from comfy.comfy_types import UnetWrapperFunction
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -76,7 +88,36 @@ class LowVramPatch:
self.key = key
self.patches = patches
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
intermediate_dtype = weight.dtype
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
def get_key_weight(model, key):
set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(model, key)
else:
op = comfy.utils.get_attr(model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass
try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass
weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(model, key)
return weight, set_func, convert_func
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
@@ -271,17 +312,23 @@ class ModelPatcher:
return list(p)
def get_key_patches(self, filter_prefix=None):
comfy.model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
bk = self.backup.get(k, None)
weight, set_func, convert_func = get_key_weight(self.model, k)
if bk is not None:
weight = bk.weight
if convert_func is None:
convert_func = lambda a, **kwargs: a
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
p[k] = [(weight, convert_func)] + self.patches[k]
else:
p[k] = (model_sd[k],)
p[k] = [(weight, convert_func)]
return p
def model_state_dict(self, filter_prefix=None):
@@ -297,8 +344,7 @@ class ModelPatcher:
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
@@ -308,23 +354,38 @@ class ModelPatcher:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
load_completely = []
loading = []
for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
module_mem = x[0]
lowvram_weight = False
if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
@@ -356,9 +417,8 @@ class ModelPatcher:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
mem_used = comfy.model_management.module_size(m)
mem_counter += mem_used
load_completely.append((mem_used, n, m))
mem_counter += module_mem
load_completely.append((module_mem, n, m))
load_completely.sort(reverse=True)
for x in load_completely:

View File

@@ -19,16 +19,12 @@
import torch
import comfy.model_management
from comfy.cli_args import args
import comfy.float
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
return weight
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
@@ -43,12 +39,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = s.bias_function is not None
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
@@ -254,20 +250,29 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@@ -277,7 +282,11 @@ def fp8_linear(self, input):
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):
@@ -295,11 +304,63 @@ class fp8_ops(manual_cast):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
def pick_operations(weight_dtype, compute_dtype, load_device=None):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if comfy.model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast

View File

@@ -6,7 +6,7 @@ from comfy import model_management
import math
import logging
import comfy.sampler_helpers
import scipy
import scipy.stats
import numpy
def get_area_and_mult(conds, x_in, timestep_in):
@@ -358,11 +358,35 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
sigs = []
last_t = -1
for t in ts:
sigs += [float(model_sampling.sigmas[int(t)])]
if t != last_t:
sigs += [float(model_sampling.sigmas[int(t)])]
last_t = t
sigs += [0.0]
return torch.FloatTensor(sigs)
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
if steps == 1:
sigma_schedule = [1.0, 0.0]
else:
if linear_steps is None:
linear_steps = steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * steps
quadratic_steps = steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
const = quadratic_coef * (linear_steps ** 2)
quadratic_sigma_schedule = [
quadratic_coef * (i ** 2) + linear_coef * i + const
for i in range(linear_steps, steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -570,8 +594,8 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
class KSAMPLER(Sampler):
@@ -729,7 +753,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas(model_sampling, scheduler_name, steps):
@@ -747,6 +771,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "beta":
sigmas = beta_scheduler(model_sampling, steps)
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas

View File

@@ -7,6 +7,7 @@ from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import yaml
import comfy.utils
@@ -25,11 +26,11 @@ import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
@@ -70,14 +71,14 @@ class CLIP:
clip = target.clip
tokenizer = target.tokenizer
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
load_device = model_options.get("load_device", model_management.text_encoder_device())
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
params['model_options'] = model_options
self.cond_stage_model = clip(**(params))
@@ -242,6 +243,13 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
self.latent_channels = 12
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -297,6 +305,10 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -315,6 +327,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def decode(self, samples_in):
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -322,16 +335,21 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
if len(samples_in.shape) == 3:
dims = samples_in.ndim - 2
if dims == 1:
pixel_samples = self.decode_tiled_1d(samples_in)
else:
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
pixel_samples = self.decode_tiled_3d(samples_in)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@@ -348,7 +366,7 @@ class VAE:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
@@ -399,6 +417,7 @@ class CLIPType(Enum):
STABLE_AUDIO = 4
HUNYUAN_DIT = 5
FLUX = 6
MOCHI = 7
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@@ -406,8 +425,46 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
class TEModel(Enum):
CLIP_L = 1
CLIP_H = 2
CLIP_G = 3
T5_XXL = 4
T5_XL = 5
T5_BASE = 6
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
return TEModel.CLIP_G
if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
return None
def t5xxl_detect(clip_data):
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
class EmptyClass:
pass
@@ -421,64 +478,65 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
te_model = detect_te_model(clip_data[0])
if te_model == TEModel.CLIP_G:
if clip_type == CLIPType.STABLE_CASCADE:
clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
elif clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
elif te_model == TEModel.CLIP_H:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype
if weight.shape[-1] == 4096:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
elif te_model == TEModel.T5_XXL:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
else:
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
elif clip_type == CLIPType.FLUX:
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@@ -544,11 +602,11 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("weight_dtype", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
@@ -562,7 +620,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
@@ -614,6 +671,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
sd = temp_sd
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
@@ -640,14 +699,21 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")

View File

@@ -80,7 +80,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
@@ -94,11 +94,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config = json.load(f)
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
if operations is None:
operations = comfy.ops.manual_cast
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
if scaled_fp8 is not None:
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
self.max_length = max_length
@@ -542,6 +551,7 @@ class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -570,6 +580,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()

View File

@@ -22,7 +22,8 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -40,7 +41,8 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
@@ -57,7 +59,8 @@ class SDXLClipModel(torch.nn.Module):
token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
cut_to = min(l_out.shape[1], g_out.shape[1])
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:

View File

@@ -10,6 +10,7 @@ import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
from . import supported_models_base
from . import latent_formats
@@ -529,12 +530,11 @@ class SD3(supported_models_base.BASE):
clip_l = True
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_g = True
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
if "dtype_t5" in t5_detect:
t5 = True
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
class StableAudio(supported_models_base.BASE):
unet_config = {
@@ -653,11 +653,8 @@ class Flux(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
dtype_t5 = None
if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class FluxSchnell(Flux):
unet_config = {
@@ -674,7 +671,36 @@ class FluxSchnell(Flux):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
sampling_settings = {
"multiplier": 1.0,
"shift": 6.0,
}
unet_extra_config = {}
latent_format = latent_formats.Mochi
memory_usage_factor = 2.0 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.GenmoMochi(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
models += [SVD_img2vid]

View File

@@ -49,6 +49,8 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
optimizations = {"fp8": False}
@classmethod
def matches(s, unet_config, state_dict=None):
@@ -71,6 +73,7 @@ class BASE:
self.unet_config = unet_config.copy()
self.sampling_settings = self.sampling_settings.copy()
self.latent_format = self.latent_format()
self.optimizations = self.optimizations.copy()
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]

View File

@@ -1,24 +1,21 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
import comfy.model_management
from transformers import T5TokenizerFast
import torch
import os
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,8 +35,9 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
@@ -64,8 +62,11 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None):
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_

View File

@@ -0,0 +1,38 @@
from comfy import sd1_clip
import comfy.text_encoders.sd3_clip
import os
from transformers import T5TokenizerFast
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
def __init__(self, **kwargs):
kwargs["attention_mask"] = True
super().__init__(**kwargs)
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return MochiTEModel_

View File

@@ -6,9 +6,9 @@ class LongClipTokenizer_(sd1_clip.SDTokenizer):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, *args, **kwargs):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -17,3 +17,14 @@ class LongClipTokenizer(sd1_clip.SD1Tokenizer):
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@@ -8,9 +8,27 @@ import comfy.model_management
import logging
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def t5_xxl_detect(state_dict, prefix=""):
out = {}
t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -20,7 +38,8 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -38,11 +57,12 @@ class SD3Tokenizer:
return {}
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
@@ -55,7 +75,8 @@ class SD3ClipModel(torch.nn.Module):
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.t5_attention_mask = t5_attention_mask
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
@@ -85,6 +106,7 @@ class SD3ClipModel(torch.nn.Module):
lg_out = None
pooled = None
out = None
extra = {}
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
@@ -95,7 +117,8 @@ class SD3ClipModel(torch.nn.Module):
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:
@@ -108,7 +131,11 @@ class SD3ClipModel(torch.nn.Module):
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
if self.t5_attention_mask:
extra["attention_mask"] = t5_output[2]["attention_mask"]
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
@@ -120,7 +147,7 @@ class SD3ClipModel(torch.nn.Module):
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
return out, pooled
return out, pooled, extra
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -130,8 +157,11 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

View File

@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
if len(dtypes) == 0:
return None
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
}
for k in MAP_BASIC:
@@ -688,9 +690,14 @@ def lanczos(samples, width, height):
return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
orig_shape = tuple(samples.shape)
if len(orig_shape) > 4:
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
samples = samples.movedim(2, 1)
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_width = samples.shape[-1]
old_height = samples.shape[-2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
@@ -699,48 +706,87 @@ def common_upscale(samples, width, height, upscale_method, crop):
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
else:
s = samples
if upscale_method == "bislerp":
return bislerp(s, width, height)
out = bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
out = lanczos(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
if len(orig_shape) == 4:
return out
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
return rows * cols
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
dims = len(tile)
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return up * val
def mult_list_upscale(a):
out = []
for i in range(len(a)):
out.append(round(get_upscale(i, a[i])))
return out
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
if pbar is not None:
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount))
upscaled.append(round(get_upscale(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
for d in range(2, dims + 2):
m = mask.narrow(d, t, 1)
m *= ((1.0/feather) * (t + 1))
m = mask.narrow(d, mask.shape[d] -1 -t, 1)
m *= ((1.0/feather) * (t + 1))
for d in range(2, dims + 2):
feather = round(get_upscale(d - 2, overlap[d - 2]))
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = out
o_d = out_div
@@ -748,8 +794,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o += ps * mask
o_d += mask
o.add_(ps * mask)
o_d.add_(mask)
if pbar is not None:
pbar.update(1)

View File

@@ -1,11 +1,21 @@
import itertools
from typing import Sequence, Mapping
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
import nodes
from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
@@ -56,6 +66,8 @@ class CacheKeySetID(CacheKeySet):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -74,6 +86,8 @@ class CacheKeySetInputSignature(CacheKeySet):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -87,11 +101,14 @@ class CacheKeySetInputSignature(CacheKeySet):
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):
@@ -112,6 +129,8 @@ class CacheKeySetInputSignature(CacheKeySet):
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:

View File

@@ -99,30 +99,44 @@ class TopologicalSort:
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
if unique_id in self.pendingNodes:
return
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
node_ids = [node_unique_id]
links = []
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
while len(node_ids) > 0:
unique_id = node_ids.pop()
if unique_id in self.pendingNodes:
continue
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
self.add_strong_link(*link)
def is_cached(self, node_id):
return False
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
@@ -146,11 +160,8 @@ class ExecutionList(TopologicalSort):
self.output_cache = output_cache
self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if self.output_cache.get(from_node_id) is not None:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def stage_node_execution(self):
assert self.staged_node_id is None

View File

@@ -16,14 +16,15 @@ class EmptyLatentAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds):
batch_size = 1
def generate(self, seconds, batch_size):
length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
@@ -58,6 +59,9 @@ class VAEDecodeAudio:
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, )
@@ -183,17 +187,10 @@ class PreviewAudio(SaveAudio):
}
class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio"

View File

@@ -1,4 +1,6 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import nodes
import comfy.utils
class SetUnionControlNetType:
@classmethod
@@ -22,6 +24,37 @@ class SetUnionControlNetType:
return (control_net,)
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"vae": ("VAE", ),
"image": ("IMAGE", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}

View File

@@ -90,6 +90,27 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
@@ -673,6 +694,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"LaplaceScheduler": LaplaceScheduler,
"VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler,

View File

@@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:

View File

@@ -1,4 +1,5 @@
import comfy.utils
import comfy_extras.nodes_post_processing
import torch
def reshape_latent_to(target_shape, latent):
@@ -145,6 +146,131 @@ class LatentBatchSeedBehavior:
return (samples_out,)
class LatentApplyOperation:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1)
return (samples_out,)
class LatentApplyOperationCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def patch(self, model, operation):
m = model.clone()
def pre_cfg_function(args):
conds_out = args["conds_out"]
if len(conds_out) == 2:
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
else:
conds_out[0] = operation(latent=conds_out[0])
return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, )
class LatentOperationTonemapReinhard:
@classmethod
def INPUT_TYPES(s):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
top = (std * 5 + mean) * multiplier
#reinhard
latent_vector_magnitude *= (1.0 / top)
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
new_magnitude *= top
return normalized_latent * new_magnitude
return (tonemap_reinhard,)
class LatentOperationSharpen:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"sharpen_radius": ("INT", {
"default": 9,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance
channels = latent.shape[1]
kernel_size = sharpen_radius * 2 + 1
kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
center = kernel_size // 2
kernel *= alpha * -10
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened
return (sharpen,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
@@ -152,4 +278,8 @@ NODE_CLASS_MAPPINGS = {
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
"LatentApplyOperationCFG": LatentApplyOperationCFG,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
"LatentOperationSharpen": LatentOperationSharpen,
}

View File

@@ -0,0 +1,119 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from enum import Enum
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return {}
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
if text_encoder_diff is not None:
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraSave": "Extract and Save Lora"
}

View File

@@ -0,0 +1,26 @@
import nodes
import torch
import comfy.model_management
class EmptyMochiLatentVideo:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/mochi"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
}

View File

@@ -17,7 +17,7 @@ class PatchModelAddDownscale:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling")

View File

@@ -101,10 +101,34 @@ class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embed."] = argument
arg_dict["x_embedder."] = argument
arg_dict["context_embedder."] = argument
arg_dict["y_embedder."] = argument
arg_dict["t_embedder."] = argument
for i in range(38):
arg_dict["joint_blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
}

View File

@@ -26,6 +26,7 @@ class PerpNeg:
FUNCTION = "patch"
CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()

View File

@@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:

View File

@@ -15,9 +15,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
@@ -36,7 +36,7 @@ class EmptySD3LatentImage:
CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3:
@@ -93,6 +93,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "conditioning/controlnet"
DEPRECATED = True
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
@@ -103,5 +104,5 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
"ControlNetApplySD3": "Apply Controlnet with VAE",
}

View File

@@ -116,6 +116,7 @@ class StableCascade_SuperResolutionControlnet:
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):

View File

@@ -154,7 +154,7 @@ class TomePatchModel:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, ratio):
self.u = None

View File

@@ -0,0 +1,22 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model, backend):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@@ -25,7 +25,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders"
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})

View File

@@ -17,7 +17,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
@@ -107,7 +107,7 @@ class VideoTriangleCFGGuidance:
return (m, )
class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
CATEGORY = "advanced/model_merging"
@classmethod
def INPUT_TYPES(s):

View File

@@ -37,6 +37,7 @@ class SaveImageWebsocket:
return {}
@classmethod
def IS_CHANGED(s, images):
return time.time()

View File

@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
value = []
for o in results:
if isinstance(o[i], ExecutionBlocker):
value.append(o[i])
else:
value.extend(o[i])
output.append(value)
else:
output.append([o[i] for o in results])
return output

View File

@@ -25,11 +25,16 @@ a111:
#comfyui:
# base_path: path/to/comfyui/
# # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads
# #is_default: true
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# diffusion_models: |
# models/diffusion_models
# models/unet
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/

View File

@@ -2,7 +2,9 @@ from __future__ import annotations
import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
@@ -44,6 +46,40 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
extension_mimetypes_cache = {
"webp" : "image",
}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
@@ -78,6 +114,13 @@ def get_input_directory() -> str:
global input_directory
return input_directory
def get_user_directory() -> str:
return user_directory
def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None:
@@ -89,6 +132,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None
def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]
if content_type in content_types:
result.append(file)
return result
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
@@ -130,11 +195,14 @@ def exists_annotated_filepath(name) -> bool:
return os.path.exists(filepath)
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
def add_model_folder_path(folder_name: str, full_folder_path: str, is_default: bool = False) -> None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
if is_default:
folder_names_and_paths[folder_name][0].insert(0, full_folder_path)
else:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())
@@ -166,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
@@ -200,6 +272,14 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
@@ -214,6 +294,10 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
@@ -242,6 +326,7 @@ def get_filename_list(folder_name: str) -> list[str]:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
@@ -257,9 +342,17 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
now = time.localtime()
input = input.replace("%year%", str(now.tm_year))
input = input.replace("%month%", str(now.tm_mon).zfill(2))
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))

View File

@@ -9,7 +9,7 @@ import folder_paths
import comfy.utils
import logging
MAX_PREVIEW_RESOLUTION = 512
MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
@@ -36,12 +36,20 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image)
@@ -71,7 +79,7 @@ def get_previewer(device, latent_format):
if previewer is None:
if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):

45
main.py
View File

@@ -6,6 +6,10 @@ import importlib.util
import folder_paths
import time
from comfy.cli_args import args
from app.logger import setup_logger
setup_logger(log_level=args.verbose)
def execute_prestartup_script():
@@ -59,6 +63,7 @@ import threading
import gc
import logging
import utils.extra_config
if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -81,7 +86,6 @@ if args.windows_standalone_build:
pass
import comfy.utils
import yaml
import execution
import server
@@ -156,7 +160,10 @@ def prompt_worker(q, server):
need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
def hijack_progress(server):
@@ -176,27 +183,6 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__":
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
@@ -218,11 +204,11 @@ if __name__ == "__main__":
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
utils.extra_config.load_extra_path_config(config_path)
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
@@ -243,21 +229,30 @@ if __name__ == "__main__":
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
if args.quick_test_for_ci:
exit(0)
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
call_on_start = None
if args.auto_launch:
def startup_server(scheme, address, port):
import webbrowser
if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1'
if ':' in address:
address = "[{}]".format(address)
webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server

View File

@@ -1,2 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@@ -1,9 +1,10 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import models_dir
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
@@ -17,6 +18,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
@@ -29,7 +31,7 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
@@ -38,102 +40,112 @@ class DownloadModelStatus():
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
DownloadStatusType.ERROR,
0,
"Invalid model name",
"Invalid model name",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if not model_directory in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file
try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)
return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir))
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
@@ -144,10 +156,11 @@ async def track_download_progress(response: aiohttp.ClientResponse,
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
last_update_time = time.time()
with open(file_path, 'wb') as f:
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
@@ -156,58 +169,39 @@ async def track_download_progress(response: aiohttp.ClientResponse,
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
os.rename(temp_file_path, file_path)
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate

View File

@@ -281,7 +281,10 @@ class VAEDecode:
DESCRIPTION = "Decodes latent images back into pixel space images."
def decode(self, vae, samples):
return (vae.decode(samples["samples"]), )
images = vae.decode(samples["samples"])
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
class VAEDecodeTiled:
@classmethod
@@ -511,10 +514,11 @@ class CheckpointLoader:
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
DEPRECATED = True
def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple:
@@ -535,7 +539,7 @@ class CheckpointLoaderSimple:
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3]
@@ -577,7 +581,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
@@ -624,7 +628,7 @@ class LoraLoader:
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
@@ -703,11 +707,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
@@ -738,7 +742,7 @@ class VAELoader:
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
@@ -754,7 +758,7 @@ class ControlNetLoader:
CATEGORY = "loaders"
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
return (controlnet,)
@@ -770,7 +774,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
return (controlnet,)
@@ -786,6 +790,7 @@ class ControlNetApply:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet"
DEPRECATED = True
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image, strength):
@@ -815,7 +820,10 @@ class ControlNetApplyAdvanced:
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
},
"optional": {"vae": ("VAE", ),
}
}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@@ -823,7 +831,7 @@ class ControlNetApplyAdvanced:
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0:
return (positive, negative)
@@ -840,7 +848,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@@ -856,7 +864,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
@@ -867,10 +875,13 @@ class UNETLoader:
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
model_options["dtype"] = torch.float8_e4m3fn
model_options["fp8_optimizations"] = True
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
@@ -878,7 +889,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@@ -892,10 +903,12 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path("clip", clip_name)
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
@@ -912,8 +925,8 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
@@ -935,7 +948,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders"
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,)
@@ -965,7 +978,7 @@ class StyleModelLoader:
CATEGORY = "loaders"
def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,)
@@ -1030,7 +1043,7 @@ class GLIGENLoader:
CATEGORY = "loaders"
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,)
@@ -1171,10 +1184,10 @@ class LatentUpscale:
if width == 0:
height = max(64, height)
width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
elif height == 0:
width = max(64, width)
height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
else:
width = max(64, width)
height = max(64, height)
@@ -1196,8 +1209,8 @@ class LatentUpscaleBy:
def upscale(self, samples, upscale_method, scale_by):
s = samples.copy()
width = round(samples["samples"].shape[3] * scale_by)
height = round(samples["samples"].shape[2] * scale_by)
width = round(samples["samples"].shape[-1] * scale_by)
height = round(samples["samples"].shape[-2] * scale_by)
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
return (s,)
@@ -1916,8 +1929,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
"ControlNetApply": "Apply ControlNet (OLD)",
"ControlNetApplyAdvanced": "Apply ControlNet",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",
@@ -2101,6 +2114,9 @@ def init_builtin_extra_nodes():
"nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
"nodes_mochi.py",
]
import_failed = []
@@ -2129,3 +2145,5 @@ def init_extra_nodes(init_custom_nodes=True):
else:
logging.warning("Please do a: pip install -r requirements.txt")
logging.warning("")
return import_failed

View File

@@ -79,7 +79,7 @@
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
"\n",
"# SD1.5\n",
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
"\n",
"# SD2\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",

View File

@@ -43,7 +43,7 @@ prompt_text = """
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
}
},
"5": {

View File

@@ -38,18 +38,20 @@ def get_images(ws, prompt):
if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done
else:
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
# bytesIO = BytesIO(out[8:])
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
continue #previews are binary data
history = get_history(prompt_id)[prompt_id]
for o in history['outputs']:
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
if 'images' in node_output:
images_output = []
for image in node_output['images']:
image_data = get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
images_output = []
if 'images' in node_output:
for image in node_output['images']:
image_data = get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
return output_images
@@ -85,7 +87,7 @@ prompt_text = """
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
}
},
"5": {
@@ -152,7 +154,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images:
# for node_id in images:

View File

@@ -81,7 +81,7 @@ prompt_text = """
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
}
},
"5": {
@@ -147,7 +147,7 @@ prompt["3"]["inputs"]["seed"] = 5
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
images = get_images(ws, prompt)
ws.close() # for in case this example is used in an environment where it will be repeatedly called, like in a Gradio app. otherwise, you'll randomly receive connection timeouts
#Commented out code to display the output images:
# for node_id in images:

158
server.py
View File

@@ -12,6 +12,8 @@ import json
import glob
import struct
import ssl
import socket
import ipaddress
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@@ -31,7 +33,6 @@ from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
@@ -39,9 +40,24 @@ class BinaryEventTypes:
async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
def get_comfyui_version():
comfyui_version = "unknown"
repo_path = os.path.dirname(os.path.realpath(__file__))
try:
import pygit2
repo = pygit2.Repository(repo_path)
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
except Exception:
try:
import subprocess
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
except Exception as e:
logging.warning(f"Failed to get ComfyUI version: {e}")
return comfyui_version.strip()
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
@@ -66,6 +82,68 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass
return loopback
def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
#in that case the Host and Origin hostnames won't match
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
if 'Host' in request.headers and 'Origin' in request.headers:
host = request.headers['Host']
origin = request.headers['Origin']
host_domain = host.lower()
parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None:
origin_domain = parsed.hostname
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
if request.method == "OPTIONS":
response = web.Response()
else:
response = await handler(request)
return response
return origin_only_middleware
class PromptServer():
def __init__(self, loop):
PromptServer.instance = self
@@ -85,6 +163,8 @@ class PromptServer():
middlewares = [cache_control]
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
else:
middlewares.append(create_origin_only_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
@@ -141,6 +221,12 @@ class PromptServer():
def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models")
def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
return web.json_response(model_types)
@routes.get("/models/{folder}")
async def get_models(request):
@@ -401,16 +487,25 @@ class PromptServer():
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
async def system_stats(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": get_comfyui_version(),
"python_version": sys.version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
"pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
},
"devices": [
{
@@ -462,14 +557,15 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
@@ -583,18 +679,22 @@ class PromptServer():
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
# NOTE: This was an experiment and WILL BE REMOVED
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
await self.send_json("download_progress", status.to_dict())
payload = status.to_dict()
payload['download_path'] = filename
await self.send_json("download_progress", payload)
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename:
if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
@@ -602,7 +702,7 @@ class PromptServer():
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
@@ -719,6 +819,9 @@ class PromptServer():
await self.send(*msg)
async def start(self, address, port, verbose=True, call_on_start=None):
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
async def start_multi_address(self, addresses, call_on_start=None):
runner = web.AppRunner(self.app, access_log=None)
await runner.setup()
ssl_ctx = None
@@ -729,17 +832,26 @@ class PromptServer():
keyfile=args.tls_keyfile)
scheme = "https"
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
logging.info("Starting server\n")
for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
self.address = address
self.port = port
if not hasattr(self, 'address'):
self.address = address #TODO: remove this
self.port = port
if ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
if verbose:
logging.info("Starting server\n")
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
if call_on_start is not None:
call_on_start(scheme, address, port)
call_on_start(scheme, self.address, self.port)
def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler)

View File

@@ -2,7 +2,7 @@
## Install test dependencies
`pip install -r tests-units/requirements.txt`
`pip install -r tests-unit/requirements.txt`
## Run tests
`pytest tests-units/`
`pytest tests-unit/`

View File

@@ -1,6 +1,7 @@
import argparse
import pytest
from requests.exceptions import HTTPError
from unittest.mock import patch
from app.frontend_management import (
FrontendManager,
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)
@pytest.fixture
def mock_os_functions():
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
patch('app.frontend_management.os.listdir') as mock_listdir, \
patch('app.frontend_management.os.rmdir') as mock_rmdir:
mock_listdir.return_value = [] # Simulate empty directory
yield mock_makedirs, mock_listdir, mock_rmdir
@pytest.fixture
def mock_download():
with patch('app.frontend_management.download_release_asset_zip') as mock:
mock.side_effect = Exception("Download failed") # Simulate download failure
yield mock
def test_finally_block(mock_os_functions, mock_download, mock_provider):
# Arrange
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
version_string = 'test-owner/test-repo@1.0.0'
# Act & Assert
with pytest.raises(Exception):
FrontendManager.init_frontend_unsafe(version_string, mock_provider)
# Assert
mock_makedirs.assert_called_once()
mock_download.assert_called_once()
mock_listdir.assert_called_once()
mock_rmdir.assert_called_once()
def test_parse_version_string():
version_string = "owner/repo@1.0.0"

View File

@@ -0,0 +1,66 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import pytest
import os
import tempfile
from unittest.mock import patch
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
def test_get_directory_by_type():
test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir
assert folder_paths.get_directory_by_type("invalid") is None
def test_annotated_filepath():
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
def test_get_annotated_filepath():
default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path():
folder_paths.add_model_folder_path("test_folder", "/test/path")
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
files, dirs = folder_paths.recursive_search(temp_dir)
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
assert os.path.samefile(full_output_folder, temp_dir)
assert filename == "test"
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"

View File

View File

@@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory
def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files
def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)
def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

@@ -1,10 +1,17 @@
import pytest
import tempfile
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
class AsyncIteratorMock:
"""
@@ -42,7 +49,7 @@ class ContentMock:
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio
async def test_download_model_success():
async def test_download_model_success(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'}
@@ -53,15 +60,13 @@ async def test_download_model_success():
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
@@ -69,6 +74,7 @@ async def test_download_model_success():
'model.sft',
'http://example.com/model.sft',
'checkpoints',
temp_dir,
mock_progress_callback
)
@@ -83,44 +89,48 @@ async def test_download_model_success():
# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.sft',
'model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
)
# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.sft',
'model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)
# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
mock_file_path = os.path.join(temp_dir, 'model.sft')
assert os.path.exists(mock_file_path)
with open(mock_file_path, 'rb') as mock_file:
assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
async def test_download_model_url_request_failure(temp_dir):
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
# Mock the check_file_exists function to return None (file doesn't exist)
with patch('model_filemanager.check_file_exists', return_value=None):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'mock_directory',
mock_progress_callback
)
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('folder_paths.folder_names_and_paths', fake_paths):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'checkpoints',
temp_dir,
mock_progress_callback
)
# Assert the expected behavior
assert isinstance(result, DownloadModelStatus)
@@ -130,7 +140,7 @@ async def test_download_model_url_request_failure():
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'mock_directory/model.safetensors',
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
@@ -139,7 +149,7 @@ async def test_download_model_url_request_failure():
)
)
mock_progress_callback.assert_called_with(
'mock_directory/model.safetensors',
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
@@ -153,98 +163,125 @@ async def test_download_model_url_request_failure():
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'../bad_path',
'../bad_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Invalid model subdirectory'
assert result.message.startswith('Invalid or unrecognized model directory')
assert result.status == 'error'
assert result.already_existed is False
@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith("Invalid folder path")
assert result.status == 'error'
assert result.already_existed is False
# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
model_name = "test_model.sft"
model_directory = "test_dir"
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}"
model_name = "model.safetensors"
folder_path = os.path.join(tmp_path, "mock_dir")
file_path = create_model_path(model_name, folder_path)
assert file_path == os.path.join(folder_path, "model.safetensors")
assert os.path.exists(os.path.dirname(file_path))
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("../path_traversal.safetensors", folder_path)
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("/etc/some_root_path", folder_path)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.sft already exists"
assert result.already_existed is True
mock_callback.assert_called_once_with(
"test/existing_model.sft",
"existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=0.1
)
full_path = os.path.join(temp_dir, 'model.sft')
result = await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=0.1
)
assert result.status == "completed"
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.sft',
'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)
@pytest.mark.asyncio
async def test_track_download_progress_interval():
async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
@@ -253,18 +290,18 @@ async def test_track_download_progress_interval():
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0
)
full_path = os.path.join(temp_dir, 'model.sft')
# Print out the actual call count and the arguments of each call for debugging
print(f"mock_callback was called {mock_callback.call_count} times")
for i, call in enumerate(mock_callback.call_args_list):
args, kwargs = call
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")
with patch('time.time', mock_time):
await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=1.0
)
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
@@ -279,27 +316,6 @@ async def test_track_download_progress_interval():
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory():
assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),

View File

@@ -0,0 +1,120 @@
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"

View File

@@ -0,0 +1,126 @@
import pytest
import yaml
import os
from unittest.mock import Mock, patch, mock_open
from utils.extra_config import load_extra_path_config
import folder_paths
@pytest.fixture
def mock_yaml_content():
return {
'test_config': {
'base_path': '~/App/',
'checkpoints': 'subfolder1',
}
}
@pytest.fixture
def mock_expanded_home():
return '/home/user'
@pytest.fixture
def yaml_config_with_appdata():
return """
test_config:
base_path: '%APPDATA%/ComfyUI'
checkpoints: 'models/checkpoints'
"""
@pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata):
return yaml.safe_load(yaml_config_with_appdata)
@pytest.fixture
def mock_expandvars_appdata():
mock = Mock()
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
return mock
@pytest.fixture
def mock_add_model_folder_path():
return Mock()
@pytest.fixture
def mock_expanduser(mock_expanded_home):
def _expanduser(path):
if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:])
return path
return _expanduser
@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content)
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
):
# Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_calls = [
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args[0] == expected_call[0]
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
assert actual_call.args[2] == expected_call[2]
# Check if yaml.safe_load was called
mock_yaml_safe_load.assert_called_once()
# Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
@patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expandvars_appdata,
yaml_config_with_appdata,
mock_yaml_content_appdata
):
# Set the mock_file to return yaml with appdata as a variable
mock_file.return_value.read.return_value = yaml_config_with_appdata
# Attach mocks
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expandvars', mock_expandvars_appdata)
monkeypatch.setattr(yaml, 'safe_load', Mock(return_value=mock_yaml_content_appdata))
# Mock expanduser to do nothing (since we're not testing it here)
monkeypatch.setattr(os.path, 'expanduser', lambda x: x)
dummy_yaml_file_name = 'dummy_path.yaml'
load_extra_path_config(dummy_yaml_file_name)
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
expected_calls = [
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
# Verify that expandvars was called
assert mock_expandvars_appdata.called

View File

@@ -95,17 +95,16 @@ class ComfyClient:
pass # Probably want to store this off for testing
history = self.get_history(prompt_id)[prompt_id]
for o in history['outputs']:
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
result.outputs[node_id] = node_output
if 'images' in node_output:
images_output = []
for image in node_output['images']:
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
image_obj = Image.open(BytesIO(image_data))
images_output.append(image_obj)
node_output['image_objects'] = images_output
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
result.outputs[node_id] = node_output
images_output = []
if 'images' in node_output:
for image in node_output['images']:
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
image_obj = Image.open(BytesIO(image_data))
images_output.append(image_obj)
node_output['image_objects'] = images_output
return result
@@ -357,6 +356,25 @@ class TestExecution:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
# We have multiple outputs. The first is invalid, but the second is valid
g.node("SaveImage", images=mix1.out(0))
g.node("SaveImage", images=mix2.out(0))
g.remove_node("removeme")
client.run(g)
# Add back in the missing node to make sure the error doesn't break the server
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
client.run(g)
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug
@@ -450,8 +468,8 @@ class TestExecution:
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
output1 = g.node("PreviewImage", images=input1.out(0))
output2 = g.node("PreviewImage", images=input1.out(0))
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input1.out(0))
result = client.run(g)
images1 = result.get_images(output1)
@@ -478,3 +496,29 @@ class TestExecution:
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked.
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
int1 = g.node("StubInt", value=1)
int2 = g.node("StubInt", value=2)
int3 = g.node("StubInt", value=3)
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))
result = client.run(g)
assert result.did_run(output), "The execution should have run"
images = result.get_images(output)
assert len(images) == 2, "Should have 2 images"
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"

View File

@@ -109,15 +109,14 @@ class ComfyClient:
continue #previews are binary data
history = self.get_history(prompt_id)[prompt_id]
for o in history['outputs']:
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
if 'images' in node_output:
images_output = []
for image in node_output['images']:
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
images_output = []
if 'images' in node_output:
for image in node_output['images']:
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
return output_images

0
utils/__init__.py Normal file
View File

28
utils/extra_config.py Normal file
View File

@@ -0,0 +1,28 @@
import os
import yaml
import folder_paths
import logging
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
base_path = os.path.expandvars(os.path.expanduser(base_path))
is_default = False
if "is_default" in conf:
is_default = conf.pop("is_default")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path, is_default)

1
web/assets/CREDIT.txt generated vendored Normal file
View File

@@ -0,0 +1 @@
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.

103
web/assets/ExtensionPanel-DZLYjWBj.js generated vendored Normal file
View File

@@ -0,0 +1,103 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, bK as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bL as script$1, A as createBaseVNode, x as createBlock, M as Fragment, N as renderList, am as toDisplayString, ap as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-CgU1oKZt.js";
import { s as script, a as script$2, b as script$3 } from "./index-DBWDcZsl.js";
import "./index-DYEEBf64.js";
const _hoisted_1 = { class: "extension-panel" };
const _hoisted_2 = { class: "mt-4" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ExtensionPanel",
setup(__props) {
const extensionStore = useExtensionStore();
const settingStore = useSettingStore();
const editingEnabledExtensions = ref({});
onMounted(() => {
extensionStore.extensions.forEach((ext) => {
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
});
});
const changedExtensions = computed(() => {
return extensionStore.extensions.filter(
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
);
});
const hasChanges = computed(() => {
return changedExtensions.value.length > 0;
});
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
const editingDisabledExtensionNames = Object.entries(
editingEnabledExtensions.value
).filter(([_, enabled]) => !enabled).map(([name]) => name);
settingStore.set("Comfy.Extension.Disabled", [
...extensionStore.inactiveDisabledExtensionNames,
...editingDisabledExtensionNames
]);
}, "updateExtensionStatus");
const applyChanges = /* @__PURE__ */ __name(() => {
window.location.reload();
}, "applyChanges");
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createVNode(unref(script$2), {
value: unref(extensionStore).extensions,
stripedRows: "",
size: "small"
}, {
default: withCtx(() => [
createVNode(unref(script), {
field: "name",
header: _ctx.$t("extensionName"),
sortable: ""
}, null, 8, ["header"]),
createVNode(unref(script), { pt: {
bodyCell: "flex items-center justify-end"
} }, {
body: withCtx((slotProps) => [
createVNode(unref(script$1), {
modelValue: editingEnabledExtensions.value[slotProps.data.name],
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
onChange: updateExtensionStatus
}, null, 8, ["modelValue", "onUpdate:modelValue"])
]),
_: 1
})
]),
_: 1
}, 8, ["value"]),
createBaseVNode("div", _hoisted_2, [
hasChanges.value ? (openBlock(), createBlock(unref(script$3), {
key: 0,
severity: "info"
}, {
default: withCtx(() => [
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
return openBlock(), createElementBlock("li", {
key: ext.name
}, [
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
createTextVNode(" " + toDisplayString(ext.name), 1)
]);
}), 128))
])
]),
_: 1
})) : createCommentVNode("", true),
createVNode(unref(script$4), {
label: _ctx.$t("reloadToApplyChanges"),
icon: "pi pi-refresh",
onClick: applyChanges,
disabled: !hasChanges.value,
text: "",
fluid: "",
severity: "danger"
}, null, 8, ["label", "disabled"])
])
]);
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ExtensionPanel-DZLYjWBj.js.map

Some files were not shown because too many files have changed in this diff Show More