1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support AliMama SD3 and Flux inpaint controlnets.

Use the ControlNetInpaintingAliMamaApply node.
This commit is contained in:
comfyanonymous
2024-09-14 09:05:16 -04:00
parent 369a6dd2c4
commit f48e390032
4 changed files with 68 additions and 11 deletions

View File

@@ -79,13 +79,19 @@ class ControlBase:
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@@ -100,9 +106,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
@@ -123,6 +129,8 @@ class ControlBase:
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -175,7 +183,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
@@ -189,6 +197,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -220,6 +229,13 @@ class ControlNet(ControlBase):
comfy.model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -410,12 +426,17 @@ def load_controlnet_mmdit(sd):
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: #inpaint controlnet
concat_mask = True
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
@@ -450,13 +471,16 @@ def load_controlnet_flux_instantx(sd):
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):