Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add model_options for text encoder.
comfy/sd.py | 21 (+13, -8)
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
         if no_init:
             return
         params = target.params.copy()
@@ -71,9 +71,14 @@ class CLIP:
 
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        dtype = model_management.text_encoder_dtype(load_device)
+        dtype = model_options.get("dtype", None)
+        if dtype is None:
+            dtype = model_management.text_encoder_dtype(load_device)
+
         params['dtype'] = dtype
         params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
+        params['model_options'] = model_options
+
         self.cond_stage_model = clip(**(params))
 
         for dt in self.cond_stage_model.dtypes:
@@ -394,7 +399,7 @@ class CLIPType(Enum):
     HUNYUAN_DIT = 5
     FLUX = 6
 
-def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@@ -464,7 +469,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     for c in clip_data:
         parameters += comfy.utils.calculate_parameters(c)
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -506,14 +511,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
 
     return (model, clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
     sd = comfy.utils.load_torch_file(ckpt_path)
-    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
     if out is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
     return out
 
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
     clip = None
     clipvision = None
     vae = None
@@ -563,7 +568,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
                 parameters = comfy.utils.calculate_parameters(clip_sd)
-                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
+                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
                 m, u = clip.load_sd(clip_sd, full_model=True)
                 if len(m) > 0:
                     m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
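After this change, callers can steer how the text encoder is instantiated by passing an options dict down to CLIP.__init__. Only the "dtype" key is read there in this commit; the full dict is also stored in params['model_options'] and handed to the model. The following is a minimal caller-side sketch of the new parameters; the file paths and the torch.float16 value are illustrative assumptions, not part of the commit.

import torch
import comfy.sd

# Pin the text encoder to fp16 instead of the auto-detected dtype.
# "dtype" is the only key CLIP.__init__ reads in this commit; any other
# keys simply travel along in params['model_options'].
te_opts = {"dtype": torch.float16}

# Checkpoint loader: UNet options (model_options) and text encoder
# options (te_model_options) are now separate dicts.
model, clip, vae, clipvision = comfy.sd.load_checkpoint_guess_config(
    "models/checkpoints/example.safetensors",   # illustrative path
    te_model_options=te_opts,
)

# Standalone text encoder loader: the same dict goes in via model_options.
clip_only = comfy.sd.load_clip(
    ["models/clip/example_te.safetensors"],     # illustrative path
    clip_type=comfy.sd.CLIPType.FLUX,
    model_options=te_opts,
)

Keeping te_model_options separate from model_options lets a workflow run the diffusion model at one precision while pinning the text encoder to another.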