mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 23:14:49 +08:00
UNET weights can now be stored in fp8.
--fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch.
This commit is contained in:
@@ -841,14 +841,14 @@ class UNetModel(nn.Module):
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
|
Reference in New Issue
Block a user