mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Basic SD3 controlnet implementation.
Still missing the node to properly use it.
This commit is contained in:
@@ -11,6 +11,7 @@ import comfy.ops
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@@ -94,13 +95,17 @@ class ControlBase:
|
||||
|
||||
for key in control:
|
||||
control_output = control[key]
|
||||
applied_to = set()
|
||||
for i in range(len(control_output)):
|
||||
x = control_output[i]
|
||||
if x is not None:
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
||||
applied_to.add(x)
|
||||
x *= self.strength
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
@@ -120,17 +125,18 @@ class ControlBase:
|
||||
if o[i].shape[0] < prev_val.shape[0]:
|
||||
o[i] = prev_val + o[i]
|
||||
else:
|
||||
o[i] += prev_val
|
||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
|
||||
self.compression_ratio = compression_ratio
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
@@ -308,6 +314,37 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
|
||||
controlnet_config = model_config.unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
||||
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
@@ -360,6 +397,8 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
|
Reference in New Issue
Block a user