mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

All the unet weights should now be initialized with the right dtype.

This commit is contained in:
comfyanonymous
2023-06-15 18:42:30 -04:00
parent cf3974c829
commit ae43f09ef7
3 changed files with 29 additions and 23 deletions


@@ -206,13 +206,13 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))
-def normalization(channels):
+def normalization(channels, dtype=None):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(32, channels)
+    return GroupNorm32(32, channels, dtype=dtype)
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
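
The change above threads a `dtype` argument through to the layer constructor, so the norm's affine weights are created directly in the target dtype rather than being built in float32 and converted afterwards. A minimal sketch of the pattern (names mirror the diff; `GroupNorm32`'s float32-forward behavior is assumed from its usage elsewhere in the repo and omitted here):

```python
import torch
import torch.nn as nn

class GroupNorm32(nn.GroupNorm):
    # In ComfyUI this subclass also runs its forward pass in float32
    # (hence the name); that detail is not reproduced in this sketch.
    pass

def normalization(channels, dtype=None):
    # nn.GroupNorm accepts a dtype kwarg, so the weight and bias tensors
    # are allocated in the requested dtype from the start.
    return GroupNorm32(32, channels, dtype=dtype)

norm = normalization(64, dtype=torch.float16)
print(norm.weight.dtype)  # the affine parameters are created as float16
```

Constructing in the right dtype avoids a wasteful float32 allocation followed by a cast, which matters when building a large UNet in half precision.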