
Load flux t5 in fp8 if weights are in fp8.

This commit is contained in:
comfyanonymous
2024-08-01 11:05:56 -04:00
parent 8d34211a7a
commit 5f98de7697
4 changed files with 29 additions and 12 deletions

@@ -1,5 +1,6 @@
 from comfy import sd1_clip
 import comfy.text_encoders.t5
+import comfy.model_management
 from transformers import T5TokenizerFast
 import torch
 import os
@@ -34,11 +35,12 @@ class FluxTokenizer:
 class FluxClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None):
+    def __init__(self, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
+        dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
         self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype)
-        self.dtypes = set([dtype])
+        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+        self.dtypes = set([dtype, dtype_t5])

     def set_clip_options(self, options):
         self.clip_l.set_clip_options(options)
@@ -62,3 +64,8 @@ class FluxClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)
+
+def flux_clip(dtype_t5=None):
+    class FluxClipModel_(FluxClipModel):
+        def __init__(self, device="cpu", dtype=None):
+            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
+    return FluxClipModel_
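
For context, a minimal usage sketch (not part of the commit; the FluxClipFP8T5 and clip_model names and the concrete dtypes are illustrative). The new flux_clip() factory returns a FluxClipModel subclass with dtype_t5 captured in a closure, so the loader can still instantiate it with only the usual device/dtype arguments:

import torch
import comfy.text_encoders.flux as flux

# flux_clip() builds a class, not an instance; dtype_t5 rides along in the closure.
FluxClipFP8T5 = flux.flux_clip(dtype_t5=torch.float8_e4m3fn)  # request fp8 for the T5 weights

# Instantiation goes through FluxClipModel.__init__, where pick_weight_dtype()
# reconciles dtype_t5 with the base dtype/device: clip_l is built at the base
# dtype, t5xxl at the (fp8) T5 dtype.
clip_model = FluxClipFP8T5(device="cpu", dtype=torch.float16)
print(clip_model.dtypes)  # typically {torch.float16, torch.float8_e4m3fn}

Returning a class rather than a ready-made instance keeps the existing loader path, which expects a clip class it can construct itself, unchanged while the weight dtype detected from the checkpoint still flows through.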