mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 07:26:31 +08:00
ControlNetApply now stacks.
It can be used to apply multiple control nets at the same time.
This commit is contained in:
25
comfy/sd.py
25
comfy/sd.py
@@ -334,8 +334,13 @@ class ControlNet:
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt):
|
||||
control_prev = None
|
||||
if self.previous_controlnet is not None:
|
||||
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
@@ -354,10 +359,15 @@ class ControlNet:
|
||||
self.control_model = model_management.unload_if_low_vram(self.control_model)
|
||||
out = []
|
||||
autocast_enabled = torch.is_autocast_enabled()
|
||||
for x in control:
|
||||
|
||||
for i in range(len(control)):
|
||||
x = control[i]
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
if control_prev is not None:
|
||||
x += control_prev[i]
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
@@ -366,7 +376,13 @@ class ControlNet:
|
||||
self.strength = strength
|
||||
return self
|
||||
|
||||
def set_previous_controlnet(self, controlnet):
|
||||
self.previous_controlnet = controlnet
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
@@ -377,6 +393,13 @@ class ControlNet:
|
||||
c.strength = self.strength
|
||||
return c
|
||||
|
||||
def get_control_models(self):
|
||||
out = []
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_control_models()
|
||||
out.append(self.control_model)
|
||||
return out
|
||||
|
||||
def load_controlnet(ckpt_path):
|
||||
controlnet_data = load_torch_file(ckpt_path)
|
||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
|
Reference in New Issue
Block a user