1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-03 07:26:31 +08:00

Some cleanups to how the text encoders are loaded.

This commit is contained in:
comfyanonymous
2024-02-19 10:29:18 -05:00
parent dbe0979b3f
commit d91f45ef28
3 changed files with 32 additions and 28 deletions

View File

@@ -22,6 +22,7 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@@ -55,6 +56,7 @@ class BASE:
return out
def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict
def process_unet_state_dict(self, state_dict):
@@ -64,7 +66,7 @@ class BASE:
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
replace_prefix = {"": self.text_encoder_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
@@ -78,7 +80,7 @@ class BASE:
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype):