1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Add support for unCLIP SD2.x models.

See _for_testing/unclip in the UI for the new nodes.

unCLIPCheckpointLoader is used to load them.

unCLIPConditioning is used to add the image cond and takes as input a
CLIPVisionEncode output which has been moved to the conditioning section.
This commit is contained in:
comfyanonymous
2023-04-01 23:19:15 -04:00
parent 0d972b85e6
commit 809bcc8ceb
17 changed files with 593 additions and 113 deletions

View File

@@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'strength' in cond[1]:
strength = cond[1]['strength']
adm_cond = None
if 'adm' in cond[1]:
adm_cond = cond[1]['adm']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength
@@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
control = None
if 'control' in cond[1]:
control = cond[1]['control']
@@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
return False
return True
def can_concat_cond(c1, c2):
@@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
@@ -327,6 +342,30 @@ def apply_control_net_to_equal_area(conds, uncond):
n['control'] = cond_cnets[x]
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
if 'adm' in x[1]:
adm_inputs = []
weights = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
adm_inputs.append(adm_out)
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: Apply Noise to Embedding Mix
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm"] = torch.cat([adm_out] * batch_size)
return conds
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@@ -422,10 +461,14 @@ class KSampler:
else:
precision_scope = contextlib.nullcontext
if hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
cond_concat = None
if hasattr(self.model, 'concat_keys'):
if hasattr(self.model, 'concat_keys'): #inpaint
cond_concat = []
for ck in self.model.concat_keys:
if denoise_mask is not None: