Mirror of https://github.com/comfyanonymous/ComfyUI.git
LoRA Trainer: LoRA training node in weight adapter scheme (#8446)
@@ -3,7 +3,56 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    pad_tensor_to_shape,
+    tucker_weight_from_conv,
+)
+
+
+class LoraDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        mat1, mat2, alpha, mid, dora_scale, reshape = weights
+        out_dim, rank = mat1.shape[0], mat1.shape[1]
+        rank, in_dim = mat2.shape[0], mat2.shape[1]
+        if mid is not None:
+            convdim = mid.ndim - 2
+            layer = (
+                torch.nn.Conv1d,
+                torch.nn.Conv2d,
+                torch.nn.Conv3d
+            )[convdim]
+        else:
+            layer = torch.nn.Linear
+        self.lora_up = layer(rank, out_dim, bias=False)
+        self.lora_down = layer(in_dim, rank, bias=False)
+        self.lora_up.weight.data.copy_(mat1)
+        self.lora_down.weight.data.copy_(mat2)
+        if mid is not None:
+            self.lora_mid = layer(mid, rank, bias=False)
+            self.lora_mid.weight.data.copy_(mid)
+        else:
+            self.lora_mid = None
+        self.rank = rank
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        if self.lora_mid is None:
+            diff = self.lora_up.weight @ self.lora_down.weight
+        else:
+            diff = tucker_weight_from_conv(
+                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
+            )
+        scale = self.alpha / self.rank
+        weight = w + scale * diff.reshape(w.shape)
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
 
 
 class LoRAAdapter(WeightAdapterBase):
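The first hunk adds LoraDiff, a trainable module built on WeightAdapterTrainBase: it holds the LoRA factors as lora_up / lora_down layers (plus an optional lora_mid for the Tucker convolution case) and, when called on a base weight w, returns w + (alpha / rank) * diff reshaped to w's shape. A minimal standalone sketch of the Linear-case update in plain PyTorch; the sizes and variable names below are illustrative, not taken from the commit:

import torch

# Illustrative sizes for a Linear layer weight (out_dim x in_dim) and a rank-4 LoRA.
out_dim, in_dim, rank, alpha = 8, 16, 4, 1.0

w = torch.randn(out_dim, in_dim)                         # frozen base weight
lora_up = torch.nn.Linear(rank, out_dim, bias=False)     # plays the role of mat1 / lora_up
lora_down = torch.nn.Linear(in_dim, rank, bias=False)    # plays the role of mat2 / lora_down
torch.nn.init.kaiming_uniform_(lora_up.weight, a=5**0.5)
torch.nn.init.constant_(lora_down.weight, 0.0)

# Same composition as LoraDiff.__call__ takes in its Linear branch:
diff = lora_up.weight @ lora_down.weight                 # (out_dim, rank) @ (rank, in_dim)
w_adapted = w + (alpha / rank) * diff.reshape(w.shape)

# Because the down factor starts at zero (as in create_train below), the adapted
# weight initially equals the base weight, so training starts from the unmodified model.
assert torch.allclose(w_adapted, w)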
@@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
+        torch.nn.init.constant_(mat2, 0.0)
+        return LoraDiff(
+            (mat1, mat2, alpha, None, None, None)
+        )
+
+    def to_train(self):
+        return LoraDiff(self.weights)
+
     @classmethod
     def load(
         cls,
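The second hunk wires LoRA into the weight adapter training scheme: create_train builds a fresh LoraDiff for a given base weight (Kaiming-initialised mat1, zero-initialised mat2, no mid/dora_scale/reshape), and to_train rebuilds a LoraDiff from an adapter's already-loaded weights. A hedged sketch of driving create_train directly; the comfy.weight_adapter.lora import path, the layer shape, and the rank/alpha/learning-rate values are assumptions, it presumes WeightAdapterTrainBase behaves like a torch.nn.Module (its passive_memory_usage iterates self.parameters()), and running it needs a ComfyUI checkout on the Python path:

import torch
from comfy.weight_adapter.lora import LoRAAdapter  # assumed module path

# Illustrative base weight for one Linear layer of the model.
base_weight = torch.randn(320, 1280)

# Build a trainable LoraDiff for this weight; mat2 starts at zero, so the
# adapter is initially a no-op on the base weight.
train_adapter = LoRAAdapter.create_train(base_weight, rank=8, alpha=8.0)

# Only the up/down factors need gradients; alpha is registered with
# requires_grad=False and is filtered out here.
params = [p for p in train_adapter.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# Inside a training step the adapter yields the effective weight for the layer.
effective_weight = train_adapter(base_weight)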