1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

InstantX depth flux controlnet.

This commit is contained in:
comfyanonymous
2024-08-29 02:14:19 -04:00
parent b33cd61070
commit ea3f39bd69
2 changed files with 32 additions and 16 deletions

View File

@@ -23,8 +23,12 @@ class ControlNetFlux(Flux):
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
self.gradient_checkpointing = False
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@@ -78,26 +82,39 @@ class ControlNetFlux(Flux):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
controlnet_double = ()
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
img = torch.cat((txt, img), 1)
controlnet_single = ()
repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_block_res_samples:
for x in controlnet_double:
out_input += (x,) * repeat
else:
out_input = (controlnet_block_res_samples * repeat)
return {"input": out_input[:self.main_model_double]}
out_input = (controlnet_double * repeat)
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2