Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 23:14:49 +08:00)
Try to keep text encoders loaded and patched to increase speed.
load_model_gpu() is now used with the text encoder models instead of just the unet.
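As a rough illustration of the idea (not ComfyUI's actual classes; PatchedModel, load and unload below are made-up names for this sketch), each managed model carries its own load_device and offload_device, so a text encoder can go through the same load/offload path as the unet instead of callers hardcoding get_torch_device() and unet_offload_device():

import torch

class PatchedModel:
    def __init__(self, model, load_device, offload_device):
        self.model = model
        self.load_device = load_device        # where the model runs (e.g. cuda)
        self.offload_device = offload_device  # where it sits when idle (e.g. cpu)

    def load(self):
        # Mirrors the commit's direction: the loader reads the device from the
        # model itself, so text encoders take the same path as the unet.
        self.model.to(self.load_device)
        return self.model

    def unload(self):
        self.model.to(self.offload_device)

# Hypothetical usage with a stand-in module in place of a CLIP text model.
text_encoder = PatchedModel(
    torch.nn.Linear(8, 8),
    load_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    offload_device=torch.device("cpu"),
)
encoder = text_encoder.load()   # moved to its load device for encoding
_ = encoder(torch.zeros(1, 8, device=text_encoder.load_device))
text_encoder.unload()           # parked back on its offload device afterwards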
@@ -216,11 +216,6 @@ current_gpu_controlnets = []
 
 model_accelerated = False
 
-def unet_offload_device():
-    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
-        return get_torch_device()
-    else:
-        return torch.device("cpu")
 
 def unload_model():
     global current_loaded_model
@@ -234,8 +229,8 @@ def unload_model():
             model_accelerated = False
 
 
-        current_loaded_model.model.to(unet_offload_device())
-        current_loaded_model.model_patches_to(unet_offload_device())
+        current_loaded_model.model.to(current_loaded_model.offload_device)
+        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
         current_loaded_model.unpatch_model()
         current_loaded_model = None
 
@@ -260,10 +255,14 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
 
-    torch_dev = get_torch_device()
+    torch_dev = model.load_device
     model.model_patches_to(torch_dev)
 
-    vram_set_state = vram_state
+    if is_device_cpu(torch_dev):
+        vram_set_state = VRAMState.DISABLED
+    else:
+        vram_set_state = vram_state
+
     if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
         model_size = model.model_size()
         current_free_mem = get_free_memory(torch_dev)
@@ -277,14 +276,14 @@ def load_model_gpu(model):
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(get_torch_device())
+        real_model.to(torch_dev)
     else:
         if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_set_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
     return current_loaded_model
 
@@ -327,6 +326,12 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -428,14 +433,19 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
+def is_device_cpu(device):
+    if hasattr(device, 'type'):
+        if (device.type == 'cpu' or device.type == 'mps'):
+            return True
+    return False
+
 def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
     if device is not None: #TODO
-        if hasattr(device, 'type'):
-            if (device.type == 'cpu' or device.type == 'mps'):
-                return False
+        if is_device_cpu(device):
+            return False
 
     if FORCE_FP32:
         return False
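For reference, a standalone sketch of the new helper's behavior (simplified: force_fp32 below is a made-up parameter standing in for the module-level FORCE_FP32 flag, and the real should_use_fp16() goes on to inspect the GPU before deciding). CPU and MPS devices are treated as cpu-like, so fp16 is declined for them:

import torch

def is_device_cpu(device):
    # Same behavior as the helper added above: cpu and mps count as "cpu-like".
    return getattr(device, "type", None) in ("cpu", "mps")

def should_use_fp16(device=None, force_fp32=False):
    if device is not None and is_device_cpu(device):
        return False
    if force_fp32:
        return False
    return True  # the real function also checks the GPU's capabilities here

print(should_use_fp16(torch.device("cpu")))   # False
print(should_use_fp16(torch.device("mps")))   # False
print(should_use_fp16(torch.device("cuda")))  # True in this sketch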