mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 07:26:31 +08:00
Add a node to load diff controlnets.
This commit is contained in:
17
comfy/sd.py
17
comfy/sd.py
@@ -400,7 +400,7 @@ class ControlNet:
|
||||
out.append(self.control_model)
|
||||
return out
|
||||
|
||||
def load_controlnet(ckpt_path):
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = load_torch_file(ckpt_path)
|
||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
pth = False
|
||||
@@ -437,6 +437,21 @@ def load_controlnet(ckpt_path):
|
||||
use_fp16=use_fp16)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
if model is not None:
|
||||
m = model.patch_model()
|
||||
model_sd = m.state_dict()
|
||||
for x in controlnet_data:
|
||||
c_m = "control_model."
|
||||
if x.startswith(c_m):
|
||||
sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
|
||||
if sd_key in model_sd:
|
||||
cd = controlnet_data[x]
|
||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||
model.unpatch_model()
|
||||
else:
|
||||
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
|
Reference in New Issue
Block a user