mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Add support for GLIGEN textbox model.
This commit is contained in:
71
nodes.py
71
nodes.py
@@ -490,6 +490,51 @@ class unCLIPConditioning:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class GLIGENLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
|
||||
|
||||
RETURN_TYPES = ("GLIGEN",)
|
||||
FUNCTION = "load_gligen"
|
||||
|
||||
CATEGORY = "_for_testing/gligen"
|
||||
|
||||
def load_gligen(self, gligen_name):
|
||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
||||
gligen = comfy.sd.load_gligen(gligen_path)
|
||||
return (gligen,)
|
||||
|
||||
class GLIGENTextBoxApply:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning_to": ("CONDITIONING", ),
|
||||
"clip": ("CLIP", ),
|
||||
"gligen_textbox_model": ("GLIGEN", ),
|
||||
"text": ("STRING", {"multiline": True}),
|
||||
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "_for_testing/gligen"
|
||||
|
||||
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
|
||||
c = []
|
||||
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
|
||||
for t in conditioning_to:
|
||||
n = [t[0], t[1].copy()]
|
||||
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
|
||||
prev = []
|
||||
if "gligen" in n[1]:
|
||||
prev = n[1]['gligen'][2]
|
||||
|
||||
n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
@@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
negative_copy = []
|
||||
|
||||
control_nets = []
|
||||
def get_models(cond):
|
||||
models = []
|
||||
for c in cond:
|
||||
if 'control' in c[1]:
|
||||
models += [c[1]['control']]
|
||||
if 'gligen' in c[1]:
|
||||
models += [c[1]['gligen'][1]]
|
||||
return models
|
||||
|
||||
for p in positive:
|
||||
t = p[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in p[1]:
|
||||
control_nets += [p[1]['control']]
|
||||
positive_copy += [[t] + p[1:]]
|
||||
for n in negative:
|
||||
t = n[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in n[1]:
|
||||
control_nets += [n[1]['control']]
|
||||
negative_copy += [[t] + n[1:]]
|
||||
|
||||
control_net_models = []
|
||||
for x in control_nets:
|
||||
control_net_models += x.get_control_models()
|
||||
comfy.model_management.load_controlnet_gpu(control_net_models)
|
||||
models = get_models(positive) + get_models(negative)
|
||||
comfy.model_management.load_controlnet_gpu(models)
|
||||
|
||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
@@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
||||
samples = samples.cpu()
|
||||
for c in control_nets:
|
||||
c.cleanup()
|
||||
for m in models:
|
||||
m.cleanup()
|
||||
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
@@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAEEncodeTiled": VAEEncodeTiled,
|
||||
"TomePatchModel": TomePatchModel,
|
||||
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
||||
"GLIGENLoader": GLIGENLoader,
|
||||
"GLIGENTextBoxApply": GLIGENTextBoxApply,
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
}
|
||||
|
Reference in New Issue
Block a user