1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Fixed model merging issue with scaled fp8.

This commit is contained in:
comfyanonymous
2024-10-20 06:24:31 -04:00
parent 471cd3eace
commit f9f9faface
3 changed files with 40 additions and 29 deletions

View File

@@ -94,6 +94,31 @@ class LowVramPatch:
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
def get_key_weight(model, key):
set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(model, key)
else:
op = comfy.utils.get_attr(model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass
try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass
weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(model, key)
return weight, set_func, convert_func
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@@ -294,14 +319,16 @@ class ModelPatcher:
if not k.startswith(filter_prefix):
continue
bk = self.backup.get(k, None)
weight, set_func, convert_func = get_key_weight(self.model, k)
if bk is not None:
weight = bk.weight
else:
weight = model_sd[k]
if convert_func is None:
convert_func = lambda a, **kwargs: a
if k in self.patches:
p[k] = [weight] + self.patches[k]
p[k] = [(weight, convert_func)] + self.patches[k]
else:
p[k] = (weight,)
p[k] = [(weight, convert_func)]
return p
def model_state_dict(self, filter_prefix=None):
@@ -317,27 +344,7 @@ class ModelPatcher:
if key not in self.patches:
return
set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(self.model, key)
else:
op = comfy.utils.get_attr(self.model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass
try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass
weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(self.model, key)
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
@@ -348,7 +355,7 @@ class ModelPatcher:
else:
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight)
temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None: