1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Add model_options for text encoder.

This commit is contained in:
comfyanonymous
2024-08-17 10:15:13 -04:00
parent 858d51f91a
commit fca42836f2
9 changed files with 66 additions and 57 deletions

View File

@@ -84,7 +84,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
@@ -94,7 +94,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
self.operations = comfy.ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers
@@ -553,7 +557,7 @@ class SD1Tokenizer:
return {}
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
super().__init__()
if name is not None:
@@ -563,7 +567,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()
if dtype is not None: