mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
LoRA Trainer: LoRA training node in weight adapter scheme (#8446)
This commit is contained in:
@@ -12,12 +12,20 @@ class WeightAdapterBase:
|
||||
weights: list[torch.Tensor]
|
||||
|
||||
@classmethod
|
||||
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
|
||||
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_train(self) -> "WeightAdapterTrainBase":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
|
||||
"""
|
||||
weight: The original weight tensor to be modified.
|
||||
*args: Additional arguments for configuration, such as rank, alpha etc.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
@@ -33,10 +41,22 @@ class WeightAdapterBase:
|
||||
|
||||
|
||||
class WeightAdapterTrainBase(nn.Module):
|
||||
# We follow the scheme of PR #7032
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# [TODO] Collaborate with LoRA training PR #7032
|
||||
def __call__(self, w):
|
||||
"""
|
||||
w: The original weight tensor to be modified.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def passive_memory_usage(self):
|
||||
raise NotImplementedError("passive_memory_usage is not implemented")
|
||||
|
||||
def move_to(self, device):
|
||||
self.to(device)
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
padded_tensor[new_slices] = tensor[orig_slices]
|
||||
|
||||
return padded_tensor
|
||||
|
||||
|
||||
def tucker_weight_from_conv(up, down, mid):
|
||||
up = up.reshape(up.size(0), up.size(1))
|
||||
down = down.reshape(down.size(0), down.size(1))
|
||||
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
|
||||
|
||||
|
||||
def tucker_weight(wa, wb, t):
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
|
||||
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
|
||||
|
Reference in New Issue
Block a user