1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Move some functions to utils.py

This commit is contained in:
comfyanonymous
2023-09-02 22:33:37 -04:00
parent 766c7b3815
commit a74c5dbf37
3 changed files with 23 additions and 24 deletions

View File

@@ -3,21 +3,6 @@ from . import model_base
from . import utils
from . import latent_formats
def state_dict_key_replace(state_dict, keys_to_replace):
for x in keys_to_replace:
if x in state_dict:
state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix):
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
for x in replace:
state_dict[x[1]] = state_dict.pop(x[0])
return state_dict
class ClipTarget:
def __init__(self, tokenizer, clip):
self.clip = clip
@@ -70,13 +55,13 @@ class BASE:
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)
return utils.state_dict_prefix_replace(state_dict, replace_prefix)