mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 07:26:31 +08:00
Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image cond and takes as input a CLIPVisionEncode output which has been moved to the conditioning section.
This commit is contained in:
@@ -1,5 +1,47 @@
|
||||
import torch
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
sd = pl_sd
|
||||
return sd
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
resblock_to_replace = {
|
||||
"ln_1": "layer_norm1",
|
||||
"ln_2": "layer_norm2",
|
||||
"mlp.c_fc": "mlp.fc1",
|
||||
"mlp.c_proj": "mlp.fc2",
|
||||
"attn.out_proj": "self_attn.out_proj",
|
||||
}
|
||||
|
||||
for resblock in range(number):
|
||||
for x in resblock_to_replace:
|
||||
for y in ["weight", "bias"]:
|
||||
k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
if k in sd:
|
||||
sd[k_to] = sd.pop(k)
|
||||
|
||||
for y in ["weight", "bias"]:
|
||||
k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
if k_from in sd:
|
||||
weights = sd.pop(k_from)
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return sd
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
|
Reference in New Issue
Block a user