Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-08-02 23:14:49 +08:00

Compare commits (2 commits)
Commits in this comparison:
- 9aac21f894
- 528d1b3563
@@ -159,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

@@ -195,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
-        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
-        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
 
         # calculate the txt bloks
-        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
-        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
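For context on the modulation plumbing above: apply_mod scales (and optionally shifts) tokens, either uniformly or per token range when a list of (start, end, index) triples is supplied. Below is a minimal sketch consistent with the call sites in this diff; the exact upstream implementation may differ in details.

from torch import Tensor

def apply_mod(tensor: Tensor, m_mult: Tensor, m_add: Tensor = None, modulation_dims=None) -> Tensor:
    """Scale (and optionally shift) tokens, uniformly or per token range."""
    if modulation_dims is None:
        # One modulation for every token: x * mult (+ add).
        if m_add is not None:
            return tensor * m_mult + m_add
        return tensor * m_mult
    # Each (start, end, idx) range uses the idx-th modulation row:
    # tensor is [B, T, D], m_mult/m_add are [B, num_rows, D].
    for start, end, idx in modulation_dims:
        tensor[:, start:end] = tensor[:, start:end] * m_mult[:, idx].unsqueeze(1)
        if m_add is not None:
            tensor[:, start:end] = tensor[:, start:end] + m_add[:, idx].unsqueeze(1)
    return tensor

The point of this commit is then simply that image tokens and text tokens no longer have to share one range list: each stream gets its own.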
@@ -244,9 +244,11 @@ class HunyuanVideo(nn.Module):
             vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
             frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
             modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+            modulation_dims_txt = [(0, None, 1)]
         else:
             vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
             modulation_dims = None
+            modulation_dims_txt = None
 
         if self.params.guidance_embed:
             if guidance is not None:
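In the token-replace branch above, vec stacks two conditioning rows (row 0 is vec_ + token_replace_vec, row 1 is vec_ + vec), and frame_tokens counts the patch tokens of one frame. Read against apply_mod, the ranges say: first-frame image tokens use row 0, all later image tokens use row 1, and text tokens, which are never token-replaced, always use row 1. A toy illustration with hypothetical sizes:

# Hypothetical sizes for illustration only.
height, width = 64, 64
patch_h, patch_w = 2, 2
frame_tokens = (width // patch_w) * (height // patch_h)  # 1024 tokens in the first frame

# First-frame image tokens take modulation row 0 (vec_ + token_replace_vec);
# every later image token takes row 1 (vec_ + vec).
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]

# Text tokens all take row 1.
modulation_dims_txt = [(0, None, 1)]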
@@ -273,14 +275,14 @@ class HunyuanVideo(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
                     return out
 
-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -295,10 +297,10 @@ class HunyuanVideo(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
                     return out
 
-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
                 img = out["img"]
             else:
                 img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
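These two hunks also change the contract for blocks_replace patches: a wrapped double block now receives modulation_dims_img and modulation_dims_txt in its args dict (a single block receives modulation_dims), and block_wrap forwards them to the original block. A minimal pass-through patch under the new contract, with a hypothetical function name, might look like this:

def my_double_block_patch(args: dict, extra: dict) -> dict:
    # extra["original_block"] is the block_wrap closure from the model;
    # forwarding args unchanged keeps the new modulation keys intact.
    out = extra["original_block"](args)
    # ... optionally inspect or modify out["img"] / out["txt"] here ...
    return out

# Sketch of registration (in practice this is wired up via transformer_options):
# blocks_replace[("double_block", 0)] = my_double_block_patch

A patch that rebuilt the call to the original block by hand, rather than passing args through, would silently drop the new keys.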
@@ -1089,7 +1089,6 @@ class ModelPatcher:
 
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
-            self.unpatch_hooks()
             if hooks is not None:
                 model_sd_keys = list(self.model_state_dict().keys())
                 memory_counter = None
@@ -1100,12 +1099,16 @@ class ModelPatcher:
                 # if have cached weights for hooks, use it
                 cached_weights = self.cached_hook_patches.get(hooks, None)
                 if cached_weights is not None:
+                    model_sd_keys_set = set(model_sd_keys)
                     for key in cached_weights:
                         if key not in model_sd_keys:
                             logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                             continue
                         self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                        model_sd_keys_set.remove(key)
+                    self.unpatch_hooks(model_sd_keys_set)
                 else:
+                    self.unpatch_hooks()
                     relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                     original_weights = None
                     if len(relevant_patches) > 0:
@@ -1116,6 +1119,8 @@ class ModelPatcher:
                             continue
                         self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                          memory_counter=memory_counter)
+            else:
+                self.unpatch_hooks()
             self.current_hooks = hooks
 
     def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
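Taken together, these hunks relocate the unconditional unpatch_hooks() that was removed from the top of patch_hooks: the cached-weights path now records which model keys the cache covered and unpatches only the remainder, while the non-cached and hooks-is-None paths still unpatch everything. A toy illustration of the set bookkeeping (values are placeholders, not real weights):

model_sd_keys = ["a.weight", "b.weight", "c.weight"]
cached_weights = {"a.weight": 1, "b.weight": 2}  # keys the hook cache covers

model_sd_keys_set = set(model_sd_keys)
for key in cached_weights:
    model_sd_keys_set.remove(key)  # patched straight from cache; leave in place

# Only keys the cache did NOT cover get restored from backup:
print(model_sd_keys_set)  # {'c.weight'}

This avoids restoring and immediately re-patching every cached key on each hook switch.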
@@ -1172,17 +1177,23 @@ class ModelPatcher:
             del out_weight
         del weight
 
-    def unpatch_hooks(self) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None
                 return
             keys = list(self.hook_backup.keys())
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+            if whitelist_keys_set:
+                for k in keys:
+                    if k in whitelist_keys_set:
+                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                        self.hook_backup.pop(k)
+            else:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
 
-            self.hook_backup.clear()
+                self.hook_backup.clear()
             self.current_hooks = None
 
     def clean_hooks(self):
         self.unpatch_hooks()
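The whitelist introduces an asymmetry worth noting: whitelisted keys are restored and popped from hook_backup one by one, while the default path restores everything and clears the dict, so backups for keys that remain patched (e.g. by cached hook weights) survive a partial unpatch. A toy model of the two paths, with placeholder values:

hook_backup = {"a.weight": "orig_a", "b.weight": "orig_b", "c.weight": "orig_c"}

def unpatch(whitelist_keys_set=None):
    keys = list(hook_backup.keys())
    if whitelist_keys_set:
        for k in keys:
            if k in whitelist_keys_set:
                hook_backup.pop(k)  # restored to the model, backup no longer needed
    else:
        hook_backup.clear()  # everything restored, drop all backups

unpatch(whitelist_keys_set={"c.weight"})
print(hook_backup)  # {'a.weight': 'orig_a', 'b.weight': 'orig_b'}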
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.25"
+__version__ = "0.3.26"
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.25"
+version = "0.3.26"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"