1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

UNET weights can now be stored in fp8.

--fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats
supported by pytorch.
This commit is contained in:
comfyanonymous
2023-12-04 11:10:00 -05:00
parent af365e4dd1
commit 31b0f6f3d8
6 changed files with 47 additions and 10 deletions

View File

@@ -841,14 +841,14 @@ class UNetModel(nn.Module):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)