mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
Initialize text encoder to target dtype.
This commit is contained in:
13
comfy/ops.py
13
comfy/ops.py
@@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs):
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
@contextmanager
|
||||
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
|
||||
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
|
||||
old_torch_nn_linear = torch.nn.Linear
|
||||
torch.nn.Linear = Linear
|
||||
force_device = device
|
||||
force_dtype = dtype
|
||||
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
if force_device is not None:
|
||||
device = force_device
|
||||
if force_dtype is not None:
|
||||
dtype = force_dtype
|
||||
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
torch.nn.Linear = linear_with_dtype
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
Reference in New Issue
Block a user