mirror of https://github.com/comfyanonymous/ComfyUI.git

Make lora code a bit cleaner.

comfyanonymous
2023-12-09 14:15:09 -05:00
parent 9e411073e9
commit cb63e230b4
2 changed files with 18 additions and 10 deletions


@@ -217,13 +217,19 @@ class ModelPatcher:
                 v = (self.calculate_weight(v[1:], v[0].clone(), key), )
 
             if len(v) == 1:
+                patch_type = "diff"
+            elif len(v) == 2:
+                patch_type = v[0]
+                v = v[1]
+
+            if patch_type == "diff":
                 w1 = v[0]
                 if alpha != 0.0:
                     if w1.shape != weight.shape:
                         print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
                         weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
-            elif len(v) == 4: #lora/locon
+            elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                 if v[2] is not None:
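The hunk above replaces the old length-based dispatch (len(v) == 1 / 2 / 4 / 8) with an explicit patch_type string; a bare one-element tuple is still treated as a "diff" patch. Below is a minimal, self-contained sketch of that dispatch and the "diff" merge; apply_patch, the key name, and the sample tensors are made up for illustration and are not ComfyUI API.

import torch

def apply_patch(weight, v, alpha=1.0, key="example.weight"):
    # Mirror the new dispatch: a bare (tensor,) tuple is a "diff" patch,
    # a ("type", payload) tuple carries an explicit patch type.
    if len(v) == 1:
        patch_type = "diff"
    elif len(v) == 2:
        patch_type = v[0]
        v = v[1]

    if patch_type == "diff":
        w1 = v[0]
        if alpha != 0.0:
            if w1.shape != weight.shape:
                print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
            else:
                weight += alpha * w1.to(weight.device, weight.dtype)
    else:
        print("patch type not recognized", patch_type, key)
    return weight

weight = torch.zeros(4, 4)
weight = apply_patch(weight, ("diff", (torch.ones(4, 4),)), alpha=0.5)
print(weight[0, 0])  # tensor(0.5000)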
@@ -237,7 +243,7 @@ class ModelPatcher:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
-            elif len(v) == 8: #lokr
+            elif patch_type == "lokr":
                 w1 = v[0]
                 w2 = v[1]
                 w1_a = v[3]
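The lora/locon branch itself is unchanged; only its guard switches from len(v) == 4 to patch_type == "lora". A minimal sketch of the underlying merge math for a plain 2D weight follows; the shapes and the lora_up/lora_down names are illustrative, not ComfyUI's.

import torch

out_features, in_features, rank = 8, 16, 4
weight = torch.zeros(out_features, in_features)
lora_up = torch.randn(out_features, rank)    # plays the role of mat1
lora_down = torch.randn(rank, in_features)   # plays the role of mat2
alpha = 0.8

# Same core operation as the patched branch: weight += alpha * (mat1 @ mat2),
# flattened and reshaped so conv weights would also collapse to 2D matrices.
weight += (alpha * torch.mm(lora_up.flatten(start_dim=1),
                            lora_down.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
print(weight.shape)  # torch.Size([8, 16])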
@@ -276,7 +282,7 @@ class ModelPatcher:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
-            else: #loha
+            elif patch_type == "loha":
                 w1a = v[0]
                 w1b = v[1]
                 if v[2] is not None:
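Likewise the lokr guard changes from len(v) == 8 to patch_type == "lokr". Its merge is a Kronecker product of two factors; a minimal sketch of just that final line, assuming the simplest case where both factors are stored directly and their sizes multiply out to the weight shape:

import torch

weight = torch.zeros(8, 16)
w1 = torch.randn(2, 4)   # first Kronecker factor
w2 = torch.randn(4, 4)   # second Kronecker factor
alpha = 0.5

# torch.kron(w1, w2) has shape (2*4, 4*4) == (8, 16), matching weight.
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
print(weight.shape)  # torch.Size([8, 16])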
@@ -305,6 +311,8 @@ class ModelPatcher:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
+            else:
+                print("patch type not recognized", patch_type, key)
 
         return weight
 
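The last hunk turns the old catch-all else: #loha into an explicit elif patch_type == "loha": and adds a fallback that warns on unrecognized patch types instead of silently treating them as loha. A minimal sketch of the loha merge, assuming the non-Tucker case of two low-rank products combined element-wise; shapes and names are illustrative only.

import torch

out_features, in_features, rank = 8, 16, 4
weight = torch.zeros(out_features, in_features)
w1a, w1b = torch.randn(out_features, rank), torch.randn(rank, in_features)
w2a, w2b = torch.randn(out_features, rank), torch.randn(rank, in_features)
alpha = 0.5

# Two low-rank products combined element-wise (Hadamard product), as in the patched branch.
m1 = torch.mm(w1a, w1b)
m2 = torch.mm(w2a, w2b)
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
print(weight.shape)  # torch.Size([8, 16])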