1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Initialize the unet directly on the target device.

This commit is contained in:
comfyanonymous
2023-07-29 14:51:56 -04:00
parent ad5866b02b
commit 4b957a0010
6 changed files with 110 additions and 103 deletions

View File

@@ -53,13 +53,13 @@ class BASE:
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]
def get_model(self, state_dict, prefix=""):
def get_model(self, state_dict, prefix="", device=None):
if self.inpaint_model():
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
else:
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
def process_clip_state_dict(self, state_dict):
return state_dict