mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Compare commits


5 Commits

Author SHA1 Message Date
Jedrzej Kosinski
c3f48337ae Create venv_management.py, add get_bootstrap_requirements_string() to help in bootstrapping a new venv's torch dependencies based on existing venv 2025-05-21 16:27:27 -07:00
ComfyUI Wiki
ded60c33a0 Update templates to 0.1.18 (#8224) 2025-05-21 11:40:08 -07:00
Michael Abrahams
8bb858e4d3 Improve performance with large number of queued prompts (#8176)
* get_current_queue_volatile

* restore get_current_queue method

* remove extra import
2025-05-21 05:14:17 -04:00
编程界的小学生
57893c843f Code Optimization and Issues Fixes in ComfyUI server (#8196)
* Update server.py

* Update server.py
2025-05-21 04:59:42 -04:00
Jedrzej Kosinski
65da29aaa9 Make torch.compile LoRA/key-compatible (#8213)
* Make torch compile node use wrapper instead of object_patch for the entire diffusion_models object, allowing key associations on diffusion_models to not break (loras, getting attributes, etc.)

* Moved torch compile code into comfy_api so it can be used by custom nodes with a degree of confidence

* Refactor set_torch_compile_wrapper to support a list of keys instead of just diffusion_model, as well as additional torch.compile args

* remove unused import

* Moved torch compile kwargs to be stored in model_options instead of attachments; attachments are more intended for things to be 'persisted', AKA not deepcopied

* Add some comments

* Remove random line of code, not sure how it got there
2025-05-21 04:56:56 -04:00
8 changed files with 216 additions and 10 deletions

app/venv_management.py (new file, 125 lines)

@@ -0,0 +1,125 @@
import torch
import torchvision
import torchaudio
from dataclasses import dataclass
import importlib

if importlib.util.find_spec("torch_directml"):
    from pip._vendor import pkg_resources


class VEnvException(Exception):
    pass


@dataclass
class TorchVersionInfo:
    name: str = None
    version: str = None
    extension: str = None
    is_nightly: bool = False
    is_cpu: bool = False
    is_cuda: bool = False
    is_xpu: bool = False
    is_rocm: bool = False
    is_directml: bool = False


def get_bootstrap_requirements_string():
    '''
    Get string to insert into a 'pip install' command to get the same torch dependencies as current venv.
    '''
    torch_info = get_torch_info(torch)
    packages = [torchvision, torchaudio]
    infos = [torch_info] + [get_torch_info(x) for x in packages]
    # directml should be first dependency, if exists
    directml_info = get_torch_directml_info()
    if directml_info is not None:
        infos = [directml_info] + infos
    # create list of strings to combine into install string
    install_str_list = []
    for info in infos:
        info_string = f"{info.name}=={info.version}"
        if not info.is_cpu and not info.is_directml:
            info_string = f"{info_string}+{info.extension}"
        install_str_list.append(info_string)
    # handle extra_index_url, if needed
    extra_index_url = get_index_url(torch_info)
    if extra_index_url:
        install_str_list.append(extra_index_url)
    # format nightly install properly
    if torch_info.is_nightly:
        install_str_list = ["--pre"] + install_str_list
    install_str = " ".join(install_str_list)
    return install_str


def get_index_url(info: TorchVersionInfo=None):
    '''
    Get --extra-index-url (or --index-url) for torch install.
    '''
    if info is None:
        info = get_torch_info()
    # for cpu, don't need any index_url
    if info.is_cpu and not info.is_nightly:
        return None
    # otherwise, format index_url
    base_url = "https://download.pytorch.org/whl/"
    if info.is_nightly:
        base_url = f"--index-url {base_url}nightly/"
    else:
        base_url = f"--extra-index-url {base_url}"
    base_url = f"{base_url}{info.extension}"
    return base_url


def get_torch_info(package=None):
    '''
    Get info about an installed torch-related package.
    '''
    if package is None:
        package = torch
    info = TorchVersionInfo(name=package.__name__)
    info.version = package.__version__
    info.extension = None
    info.is_nightly = False
    # get extension, separate from version
    info.version, info.extension = info.version.split('+', 1)
    if info.extension.startswith('cpu'):
        info.is_cpu = True
    elif info.extension.startswith('cu'):
        info.is_cuda = True
    elif info.extension.startswith('rocm'):
        info.is_rocm = True
    elif info.extension.startswith('xpu'):
        info.is_xpu = True
    # TODO: add checks for some odd pytorch versions, if possible
    # check if nightly install
    if 'dev' in info.version:
        info.is_nightly = True
    return info


def get_torch_directml_info():
    '''
    Get info specifically about torch-directml package.
    Returns None if torch-directml is not installed.
    '''
    # the import string and the pip string are different
    pip_name = "torch-directml"
    # if no torch_directml, do nothing
    if not importlib.util.find_spec("torch_directml"):
        return None
    info = TorchVersionInfo(name=pip_name)
    info.is_directml = True
    for p in pkg_resources.working_set:
        if p.project_name.lower() == pip_name:
            info.version = p.version
    if p.version is None:
        return None
    return info


if __name__ == '__main__':
    print(get_bootstrap_requirements_string())
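
To make the output concrete, here is a minimal usage sketch; the torch version and the cu121 tag are illustrative assumptions, not captured output, and it assumes the ComfyUI root is on sys.path and that the installed torch builds carry a '+<extension>' local version tag (which is what get_torch_info() expects):

# Sketch only: versions and the cu121 extension are assumed for illustration.
from app.venv_management import get_bootstrap_requirements_string

req_str = get_bootstrap_requirements_string()
# For a stable CUDA 12.1 venv this comes out roughly as:
#   torch==2.3.0+cu121 torchvision==0.18.0+cu121 torchaudio==2.3.0+cu121
#   --extra-index-url https://download.pytorch.org/whl/cu121
# For a nightly venv it is prefixed with --pre and uses
#   --index-url https://download.pytorch.org/whl/nightly/<extension> instead.
print(f"pip install {req_str}")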

comfy_api/torch_helpers/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
from .torch_compile import set_torch_compile_wrapper

__all__ = [
    "set_torch_compile_wrapper",
]

comfy_api/torch_helpers/torch_compile.py (new file, 69 lines)

@@ -0,0 +1,69 @@
from __future__ import annotations
import torch

import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
    from comfy.model_patcher import ModelPatcher
    from comfy.patcher_extension import WrapperExecutor


COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"


def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that will refer to the compiled_diffusion_model.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        try:
            orig_modules = {}
            for key, value in compiled_module_dict.items():
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            for key, value in orig_modules.items():
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper


def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: list[str]=["diffusion_model"], *args, **kwargs):
    '''
    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
    When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, it will perform torch.compile on only the selected modules.
    '''
    # clear out any other torch.compile wrappers
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
    # if no keys, default to 'diffusion_model'
    if not keys:
        keys = ["diffusion_model"]
    # create kwargs dict that can be referenced later
    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }
    # get a dict of compiled keys
    compiled_modules = {}
    for key in keys:
        compiled_modules[key] = torch.compile(
            model=model.get_model_object(key),
            **compile_kwargs,
        )
    # add torch.compile wrapper
    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    # store wrapper to run on BaseModel's apply_model function
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
    # keep compile kwargs for reference
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
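
To show the new API from a custom node's point of view, here is a minimal sketch; the node class and the per-submodule key paths are assumptions for illustration, not part of this commit, and the keys must resolve on the model via get_model_object:

from comfy_api.torch_helpers import set_torch_compile_wrapper


class TorchCompileSelectedBlocks:  # hypothetical node, not in this diff
    def patch(self, model, backend):
        m = model.clone()
        # compile only the listed submodules; key paths are illustrative
        set_torch_compile_wrapper(
            model=m,
            backend=backend,
            keys=["diffusion_model.input_blocks", "diffusion_model.output_blocks"],
        )
        return (m, )

Because the compiled modules are only swapped in inside the apply_model wrapper, key-based patches such as LoRAs still see the original attributes on diffusion_model.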

comfy_extras/nodes_torch_compile.py

@@ -1,4 +1,5 @@
import torch
+from comfy_api.torch_helpers import set_torch_compile_wrapper

class TorchCompileModel:
    @classmethod
@@ -14,7 +15,7 @@ class TorchCompileModel:
    def patch(self, model, backend):
        m = model.clone()
-        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
+        set_torch_compile_wrapper(model=m, backend=backend)
        return (m, )

NODE_CLASS_MAPPINGS = {

execution.py

@@ -909,7 +909,6 @@ class PromptQueue:
        self.currently_running = {}
        self.history = {}
        self.flags = {}
-        server.prompt_queue = self

    def put(self, item):
        with self.mutex:
@@ -954,6 +953,7 @@ class PromptQueue:
            self.history[prompt[1]].update(history_result)
            self.server.queue_updated()

+    # Note: slow
    def get_current_queue(self):
        with self.mutex:
            out = []
@@ -961,6 +961,13 @@ class PromptQueue:
                out += [x]
            return (out, copy.deepcopy(self.queue))

+    # read-safe as long as queue items are immutable
+    def get_current_queue_volatile(self):
+        with self.mutex:
+            running = [x for x in self.currently_running.values()]
+            queued = copy.copy(self.queue)
+            return (running, queued)
+
    def get_tasks_remaining(self):
        with self.mutex:
            return len(self.queue) + len(self.currently_running)
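
The performance win comes from get_current_queue_volatile taking a shallow copy of the queue rather than deep-copying every queued prompt, which is read-safe as long as queue items are never mutated. A rough sketch of the difference, with the item shape and counts as illustrative assumptions:

# Illustrative micro-benchmark only; item structure and sizes are assumed.
import copy
import timeit

queue = [(i, f"prompt-{i}", {"workflow": {"nodes": list(range(200))}}) for i in range(5000)]

# deepcopy duplicates every nested dict and list in every queued item
print("deepcopy:", timeit.timeit(lambda: copy.deepcopy(queue), number=5))
# copy.copy duplicates only the outer list, giving readers a cheap, stable snapshot
print("copy:    ", timeit.timeit(lambda: copy.copy(queue), number=5))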

main.py

@@ -260,7 +260,6 @@ def start_comfyui(asyncio_loop=None):
        asyncio_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(asyncio_loop)
    prompt_server = server.PromptServer(asyncio_loop)
-    q = execution.PromptQueue(prompt_server)
    hook_breaker_ac10a0.save_functions()
    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
@@ -271,7 +270,7 @@ def start_comfyui(asyncio_loop=None):
    prompt_server.add_routes()
    hijack_progress(prompt_server)

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()

    if args.quick_test_for_ci:
        exit(0)

requirements.txt

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.19.9
-comfyui-workflow-templates==0.1.14
+comfyui-workflow-templates==0.1.18
torch
torchsde
torchvision

server.py

@@ -29,6 +29,7 @@ import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
@@ -159,7 +160,7 @@ class PromptServer():
        self.custom_node_manager = CustomNodeManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
-        self.prompt_queue = None
+        self.prompt_queue = execution.PromptQueue(self)
        self.loop = loop
        self.messages = asyncio.Queue()
        self.client_session:Optional[aiohttp.ClientSession] = None
@@ -226,7 +227,7 @@ class PromptServer():
            return response

        @routes.get("/embeddings")
-        def get_embeddings(self):
+        def get_embeddings(request):
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@@ -282,7 +283,6 @@ class PromptServer():
                    a.update(f.read())
                    b.update(image.file.read())
                    image.file.seek(0)
-                    f.close()
                return a.hexdigest() == b.hexdigest()
            return False
@@ -621,7 +621,7 @@ class PromptServer():
        @routes.get("/queue")
        async def get_queue(request):
            queue_info = {}
-            current_queue = self.prompt_queue.get_current_queue()
+            current_queue = self.prompt_queue.get_current_queue_volatile()
            queue_info['queue_running'] = current_queue[0]
            queue_info['queue_pending'] = current_queue[1]
            return web.json_response(queue_info)