
Improved memory management. (#5450)

* Less fragile memory management.

* Fix issue.

* Remove useless function.

* Prevent and detect some types of memory leaks.

* Run garbage collector when switching workflow if needed.

* Fix issue.
comfyanonymous authored 2024-12-02 14:39:34 -05:00, committed by GitHub
parent 2d5b3e0078, commit 79d5ceae6e
4 changed files with 119 additions and 120 deletions
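In short: each ModelPatcher already carries a patches_uuid identifying its patch set; this change makes the underlying model remember which patch set is currently baked into its weights, so switching between workflows that reuse the same patches can skip the unpatch/re-patch round trip. A simplified sketch of that check (illustrative only; the helper name needs_repatch is not in the codebase, the real logic lives inline in partially_load below):

    # Illustrative helper, not actual ComfyUI code: should the currently
    # loaded weights be unpatched and rebuilt before a partial load?
    def needs_repatch(model, patcher, force_patch_weights=False):
        if model.current_weight_patches_uuid is None:
            return False  # no patched weights are resident, nothing to undo
        return (model.current_weight_patches_uuid != patcher.patches_uuid
                or force_patch_weights)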


@@ -139,6 +139,7 @@ class ModelPatcher:
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
         self.patches_uuid = uuid.uuid4()
+        self.parent = None

         if not hasattr(self.model, 'model_loaded_weight_memory'):
             self.model.model_loaded_weight_memory = 0
@@ -149,6 +150,9 @@
         if not hasattr(self.model, 'model_lowvram'):
             self.model.model_lowvram = False

+        if not hasattr(self.model, 'current_weight_patches_uuid'):
+            self.model.current_weight_patches_uuid = None
+
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -172,6 +176,7 @@
         n.model_options = copy.deepcopy(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
+        n.parent = self
         return n

     def is_clone(self, other):
@@ -464,6 +469,7 @@ class ModelPatcher:
         self.model.lowvram_patch_counter += patch_counter
         self.model.device = device_to
         self.model.model_loaded_weight_memory = mem_counter
+        self.model.current_weight_patches_uuid = self.patches_uuid

     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         for k in self.object_patches:
@@ -498,6 +504,7 @@
                 else:
                     comfy.utils.set_attr_param(self.model, k, bk.weight)

+            self.model.current_weight_patches_uuid = None
             self.backup.clear()

         if device_to is not None:
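The two one-line additions above keep that marker accurate: load() stamps the model with the uuid of the patch set it just applied, and unpatch_model() clears it when the backed-up weights are restored, so partially_load() below can trust the comparison. In outline (a simplified sketch, not the real class or full methods):

    class ModelPatcherSketch:
        """Simplified sketch of the invariant kept by the two hunks above."""

        def load(self, device_to, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
            # ... apply weight patches and move modules to device_to ...
            self.model.current_weight_patches_uuid = self.patches_uuid  # weights now match this patch set

        def unpatch_model(self, device_to=None, unpatch_weights=True):
            if unpatch_weights:
                # ... restore the backed-up, unpatched original weights ...
                self.model.current_weight_patches_uuid = None  # no patch set is applied anymore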
@@ -568,21 +575,42 @@
             self.model.model_loaded_weight_memory -= memory_freed
         return memory_freed

-    def partially_load(self, device_to, extra_memory=0):
-        self.unpatch_model(unpatch_weights=False)
+    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
+        unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
+        # TODO: force_patch_weights should not unload + reload full model
+        used = self.model.model_loaded_weight_memory
+        self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+        if unpatch_weights:
+            extra_memory += (used - self.model.model_loaded_weight_memory)
+
         self.patch_model(load_weights=False)
         full_load = False
-        if self.model.model_lowvram == False:
+        if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
             return 0
         if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
             full_load = True
         current_used = self.model.model_loaded_weight_memory
-        self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
+        try:
+            self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+        except Exception as e:
+            self.detach()
+            raise e
+
         return self.model.model_loaded_weight_memory - current_used

+    def detach(self, unpatch_all=True):
+        self.model_patches_to(self.offload_device)
+        if unpatch_all:
+            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+        return self.model
+
     def current_loaded_device(self):
         return self.model.device

     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
         return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
+
+    def __del__(self):
+        self.detach(unpatch_all=False)
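From a caller's perspective, the new pieces fit together roughly like this (a hypothetical sketch, not ComfyUI's model_management code; base_patcher, load_device and budget are placeholder names): partially_load() reports how much device memory it actually consumed, detach() is the explicit cleanup path and is also invoked if load() raises, and __del__ offloads whatever patches remain when a patcher is garbage collected.

    # Hypothetical caller sketch (illustrative only).
    patcher = base_patcher.clone()   # clone keeps a parent link for bookkeeping
    used = patcher.partially_load(load_device, extra_memory=budget)
    budget -= used                   # partially_load returns the memory it actually took
    # ... run inference ...
    patcher.detach()                 # move patches to the offload device and restore weights
    # If detach() is never called, __del__ still moves the patches back to the
    # offload device when the patcher is garbage collected.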