1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 23:14:49 +08:00

Support llava clip vision model.

This commit is contained in:
comfyanonymous
2025-03-06 00:24:43 -05:00
parent 85ef295069
commit 0bef826a98
4 changed files with 61 additions and 3 deletions

View File

@@ -211,6 +211,15 @@ class CLIPVision(torch.nn.Module):
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@@ -220,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
else:
self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)
projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)