mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-03 23:49:57 +08:00
Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
75b9b55b22 | ||
|
1765f1c60c | ||
|
1de69fe4d5 | ||
|
ae197f651b |
@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
@@ -78,10 +78,9 @@ def apply_rotary_emb(
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xq_out = (xq * cos + rotate_half(xq) * sin)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
xk_out = (xk * cos + rotate_half(xk) * sin)
|
||||
else:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
||||
|
@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
||||
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
rope = (rope[0].to(x), rope[1].to(x))
|
||||
return rope
|
||||
|
||||
|
||||
|
@@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
@@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
if old_weight.shape[offset[0]] < offset[1] + offset[2]:
|
||||
exp = list(weight.shape)
|
||||
exp[offset[0]] = offset[1] + offset[2]
|
||||
new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
|
||||
new[:old_weight.shape[0]] = old_weight
|
||||
old_weight = new
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
|
@@ -296,7 +296,7 @@ class LoadedModel:
|
||||
|
||||
def model_memory_required(self, device):
|
||||
if device == self.model.current_loaded_device():
|
||||
return 0
|
||||
return self.model_offloaded_memory()
|
||||
else:
|
||||
return self.model_memory()
|
||||
|
||||
@@ -308,6 +308,12 @@ class LoadedModel:
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
|
||||
if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
@@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
|
||||
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
for loaded_model in models_already_loaded:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
|
@@ -102,7 +102,7 @@ class ModelPatcher:
|
||||
self.size = size
|
||||
self.model = model
|
||||
if not hasattr(self.model, 'device'):
|
||||
logging.info("Model doesn't have a device attribute.")
|
||||
logging.debug("Model doesn't have a device attribute.")
|
||||
self.model.device = offload_device
|
||||
elif self.model.device is None:
|
||||
self.model.device = offload_device
|
||||
|
@@ -457,8 +457,27 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
block_map = {
|
||||
"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||
"norm1.linear.weight": "img_mod.lin.weight",
|
||||
"norm1.linear.bias": "img_mod.lin.bias",
|
||||
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
||||
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
||||
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
||||
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
||||
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
||||
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
||||
"ff.net.2.weight": "img_mlp.2.weight",
|
||||
"ff.net.2.bias": "img_mlp.2.bias",
|
||||
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@@ -474,15 +493,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
|
||||
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
|
||||
|
||||
block_map = {#TODO
|
||||
block_map = {
|
||||
"norm.linear.weight": "modulation.lin.weight",
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
||||
|
||||
MAP_BASIC = { #TODO
|
||||
MAP_BASIC = {
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
("img_in.bias", "x_embedder.bias"),
|
||||
("img_in.weight", "x_embedder.weight"),
|
||||
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
||||
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
||||
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
||||
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
||||
("txt_in.bias", "context_embedder.bias"),
|
||||
("txt_in.weight", "context_embedder.weight"),
|
||||
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
||||
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
||||
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
|
Reference in New Issue
Block a user