1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Lowvram mode for gligen and fix some lowvram issues.

This commit is contained in:
comfyanonymous
2023-05-05 18:01:21 -04:00
parent 9bd33b6bd4
commit cb1551b819
4 changed files with 41 additions and 11 deletions

View File

@@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
@@ -805,13 +818,13 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context, transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h)
h = self.middle_block(h, emb, context, transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop()
@@ -828,7 +841,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
h = module(h, emb, context, transformer_options, output_shape)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)