mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 23:49:57 +08:00
Use common function for casting weights to input.
This commit is contained in:
25
comfy/ops.py
25
comfy/ops.py
@@ -19,14 +19,17 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False):
|
||||
return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
bias = s.bias_function(bias)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
@@ -168,6 +171,21 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
@@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
Reference in New Issue
Block a user