Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load the unCLIP checkpoints. unCLIPConditioning is used to add the image conditioning; it takes as input a CLIPVisionEncode output (that node has been moved to the conditioning section).
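At the sampler level, the image conditioning travels as an extra 'adm' entry in each conditioning dict: unCLIPConditioning attaches (CLIP vision output, weight) pairs, encode_adm (added further down) collapses them into a single noise-augmented tensor, and the per-cond code in the first hunk then hands that tensor to the model as c_adm. A minimal sketch of the data layout, inferred from how this diff reads it (DummyClipVisionOutput is a hypothetical stand-in for a real CLIPVisionEncode result):

import torch

class DummyClipVisionOutput:
    # Hypothetical stand-in for a CLIPVisionEncode result; the sampler only reads .image_embeds.
    def __init__(self, dim=1024):
        self.image_embeds = torch.randn(1, dim)

# One conditioning entry as the patched sampler sees it: [cross-attention tensor, options dict].
# The 'adm' value is a list of (clip vision output, weight) pairs, matching adm_c[0].image_embeds
# and adm_c[1] in encode_adm below.
text_cond = torch.randn(1, 77, 1024)
positive = [[text_cond, {"strength": 1.0, "adm": [(DummyClipVisionOutput(), 1.0)]}]]
print(positive[0][1]["adm"][0][0].image_embeds.shape)   # torch.Size([1, 1024])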
@@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         if 'strength' in cond[1]:
             strength = cond[1]['strength']
 
+        adm_cond = None
+        if 'adm' in cond[1]:
+            adm_cond = cond[1]['adm']
+
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
         mult = torch.ones_like(input_x) * strength
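For context, the unchanged lines at the end of this hunk crop the latent to the conditioning's area and build a per-pixel strength mask. A self-contained illustration of that slicing (shapes and the area tuple are arbitrary examples):

import torch

x_in = torch.randn(1, 4, 64, 64)    # latent batch (B, C, H, W)
area = (32, 32, 8, 8)               # (height, width, y offset, x offset), per the slicing above
strength = 0.75

input_x = x_in[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength   # per-pixel weight used when compositing this cond
print(input_x.shape, mult.shape)    # torch.Size([1, 4, 32, 32]) torch.Size([1, 4, 32, 32])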
@@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 cropped.append(cr)
             conditionning['c_concat'] = torch.cat(cropped, dim=1)
 
+        if adm_cond is not None:
+            conditionning['c_adm'] = adm_cond
+
         control = None
         if 'control' in cond[1]:
             control = cond[1]['control']
@@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             if 'c_concat' in c1:
                 if c1['c_concat'].shape != c2['c_concat'].shape:
                     return False
+            if 'c_adm' in c1:
+                if c1['c_adm'].shape != c2['c_adm'].shape:
+                    return False
             return True
 
         def can_concat_cond(c1, c2):
@@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     def cond_cat(c_list):
         c_crossattn = []
         c_concat = []
+        c_adm = []
         for x in c_list:
             if 'c_crossattn' in x:
                 c_crossattn.append(x['c_crossattn'])
             if 'c_concat' in x:
                 c_concat.append(x['c_concat'])
+            if 'c_adm' in x:
+                c_adm.append(x['c_adm'])
         out = {}
         if len(c_crossattn) > 0:
             out['c_crossattn'] = [torch.cat(c_crossattn)]
         if len(c_concat) > 0:
             out['c_concat'] = [torch.cat(c_concat)]
+        if len(c_adm) > 0:
+            out['c_adm'] = torch.cat(c_adm)
         return out
 
     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
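The cond_cat change batches the per-entry c_adm tensors along the batch dimension, mirroring c_crossattn and c_concat except that c_adm is not wrapped in a list. A quick runnable check of the shapes involved (the 2048 width, image embedding plus noise-level embedding, is only illustrative):

import torch

# Two entries being batched together (e.g. a cond and an uncond), each carrying
# a (1, 2048) adm vector.
c_list = [
    {"c_adm": torch.randn(1, 2048)},
    {"c_adm": torch.randn(1, 2048)},
]
c_adm = [x["c_adm"] for x in c_list if "c_adm" in x]
print(torch.cat(c_adm).shape)   # torch.Size([2, 2048]) -- fed to the model as c_adm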
@@ -327,6 +342,30 @@ def apply_control_net_to_equal_area(conds, uncond):
             n['control'] = cond_cnets[x]
             uncond[temp[1]] = [o[0], n]
 
+def encode_adm(noise_augmentor, conds, batch_size, device):
+    for t in range(len(conds)):
+        x = conds[t]
+        if 'adm' in x[1]:
+            adm_inputs = []
+            weights = []
+            adm_in = x[1]["adm"]
+            for adm_c in adm_in:
+                adm_cond = adm_c[0].image_embeds
+                weight = adm_c[1]
+                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+                weights.append(weight)
+                adm_inputs.append(adm_out)
+
+            adm_out = torch.stack(adm_inputs).sum(0)
+            #TODO: Apply Noise to Embedding Mix
+        else:
+            adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
+        x[1] = x[1].copy()
+        x[1]["adm"] = torch.cat([adm_out] * batch_size)
+
+    return conds
+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
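encode_adm is the heart of the unCLIP path: each (CLIP vision output, weight) pair is run through the checkpoint's noise augmentor, concatenated with the returned noise-level embedding, scaled by its weight, summed, and tiled to the batch size (the collected weights are not used yet, per the TODO). A rough runnable sketch of that mixing step, with a mock augmentor and an illustrative embedding size standing in for the checkpoint's real noise augmentation module:

import torch

def mock_noise_augmentor(image_embeds, noise_level):
    # Stand-in for the checkpoint's noise augmentor: returns the (optionally noised)
    # embedding plus an embedding of the noise level. Shapes only; no real noising here.
    return image_embeds, torch.zeros_like(image_embeds)

device = "cpu"
batch_size = 2
# (image_embeds, weight) pairs; a 1024-wide embedding is just an illustrative size.
adm_in = [(torch.randn(1, 1024), 0.75),
          (torch.randn(1, 1024), 0.25)]

adm_inputs = []
for image_embeds, weight in adm_in:
    c_adm, noise_level_emb = mock_noise_augmentor(image_embeds.to(device),
                                                  noise_level=torch.tensor([0], device=device))
    adm_inputs.append(torch.cat((c_adm, noise_level_emb), 1) * weight)

adm_out = torch.stack(adm_inputs).sum(0)      # weighted sum of the augmented embeddings
adm_out = torch.cat([adm_out] * batch_size)   # tiled to the batch, as in x[1]["adm"]
print(adm_out.shape)                          # torch.Size([2, 2048])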
@@ -422,10 +461,14 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext
 
+        if hasattr(self.model, 'noise_augmentor'): #unclip
+            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
+            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
 
         cond_concat = None
-        if hasattr(self.model, 'concat_keys'):
+        if hasattr(self.model, 'concat_keys'): #inpaint
             cond_concat = []
             for ck in self.model.concat_keys:
                 if denoise_mask is not None:
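The KSampler hookup is attribute-based: a model exposing noise_augmentor is treated as an unCLIP model and has its positive/negative conditioning run through encode_adm, while concat_keys (as before) marks the inpainting path. A small illustration of that dispatch pattern with dummy stand-in classes (not ComfyUI's actual model types):

class UnCLIPModelStub:
    noise_augmentor = object()              # only the presence of the attribute matters

class InpaintModelStub:
    concat_keys = ("mask", "masked_image")  # hypothetical keys, for illustration only

def sampler_paths(model):
    paths = []
    if hasattr(model, 'noise_augmentor'):   # unclip
        paths.append("run positive/negative conds through encode_adm")
    if hasattr(model, 'concat_keys'):       # inpaint
        paths.append("build cond_concat from the denoise mask")
    return paths or ["plain txt2img path"]

print(sampler_paths(UnCLIPModelStub()))   # ['run positive/negative conds through encode_adm']
print(sampler_paths(InpaintModelStub()))  # ['build cond_concat from the denoise mask']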