
Try to keep text encoders loaded and patched to increase speed.

load_model_gpu() is now used with the text encoder models instead of just
the unet.
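
For context, a minimal sketch (not part of this commit; the function name encode_prompt and the exact attribute paths are illustrative assumptions) of the calling pattern the message describes: the text encoder's ModelPatcher goes through the same load_model_gpu() cache as the unet's, so repeated prompts hit an encoder that is already loaded and already patched.

    import comfy.model_management as model_management

    def encode_prompt(clip, tokens):
        # clip.patcher is assumed to be the text encoder's ModelPatcher.
        # load_model_gpu() short-circuits when the model is already the
        # currently loaded one, so the encoder stays resident and patched
        # between calls instead of being unpatched after every encode.
        model_management.load_model_gpu(clip.patcher)
        return clip.cond_stage_model.encode(tokens)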
comfyanonymous
2023-07-01 13:22:51 -04:00
parent 97ee230682
commit b6a60fa696
4 changed files with 48 additions and 40 deletions

comfy/model_management.py

@@ -216,11 +216,6 @@ current_gpu_controlnets = []
 
 model_accelerated = False
 
-def unet_offload_device():
-    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
-        return get_torch_device()
-    else:
-        return torch.device("cpu")
 
 def unload_model():
     global current_loaded_model
@@ -234,8 +229,8 @@ def unload_model():
             model_accelerated = False
 
-        current_loaded_model.model.to(unet_offload_device())
-        current_loaded_model.model_patches_to(unet_offload_device())
+        current_loaded_model.model.to(current_loaded_model.offload_device)
+        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
 
         current_loaded_model.unpatch_model()
         current_loaded_model = None
@@ -260,10 +255,14 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
 
-    torch_dev = get_torch_device()
+    torch_dev = model.load_device
     model.model_patches_to(torch_dev)
 
-    vram_set_state = vram_state
+    if is_device_cpu(torch_dev):
+        vram_set_state = VRAMState.DISABLED
+    else:
+        vram_set_state = vram_state
+
     if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
         model_size = model.model_size()
         current_free_mem = get_free_memory(torch_dev)
@@ -277,14 +276,14 @@ def load_model_gpu(model):
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(get_torch_device())
+        real_model.to(torch_dev)
     else:
         if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_set_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
 
     return current_loaded_model
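
The two hunks above read load_device and offload_device off the patcher object instead of calling get_torch_device() and unet_offload_device() globally. A minimal sketch of the contract this assumes (the attribute names come from the diff; the constructor shape is an assumption, and the real ModelPatcher carries much more state):

    class ModelPatcher:
        def __init__(self, model, load_device, offload_device):
            self.model = model
            self.load_device = load_device        # where load_model_gpu() places the model
            self.offload_device = offload_device  # where unload_model() parks it

This is what lets a text encoder declare a GPU load_device while keeping a CPU offload_device, rather than every model sharing the unet's devices.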
@@ -327,6 +326,12 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -428,14 +433,19 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
+def is_device_cpu(device):
+    if hasattr(device, 'type'):
+        if (device.type == 'cpu' or device.type == 'mps'):
+            return True
+    return False
+
 def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
     if device is not None: #TODO
-        if hasattr(device, 'type'):
-            if (device.type == 'cpu' or device.type == 'mps'):
-                return False
+        if is_device_cpu(device):
+            return False
 
     if FORCE_FP32:
         return False
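
A short illustrative check of the refactored helper (not part of the diff): is_device_cpu() deliberately groups Apple's MPS backend with the CPU, so should_use_fp16() still returns False for both, exactly as the inlined check did before.

    import torch
    import comfy.model_management as model_management

    # torch.device objects can be constructed even if the backend is absent.
    print(model_management.is_device_cpu(torch.device('cpu')))    # True
    print(model_management.is_device_cpu(torch.device('mps')))    # True: mps counts as cpu here
    print(model_management.is_device_cpu(torch.device('cuda')))   # False
    print(model_management.should_use_fp16(torch.device('mps')))  # False, via is_device_cpu()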