mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-02 15:04:50 +08:00
Refactor comfy.ops
comfy.ops -> comfy.ops.disable_weight_init This should make it more clear what they actually do. Some unused code has also been removed.
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Optional, Any
|
||||
|
||||
from comfy import model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
@@ -48,7 +49,7 @@ class Upsample(nn.Module):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = comfy.ops.Conv2d(in_channels,
|
||||
self.conv = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -78,7 +79,7 @@ class Downsample(nn.Module):
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = comfy.ops.Conv2d(in_channels,
|
||||
self.conv = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
@@ -105,30 +106,30 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = comfy.ops.Conv2d(in_channels,
|
||||
self.conv1 = ops.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = comfy.ops.Linear(temb_channels,
|
||||
self.temb_proj = ops.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||
self.conv2 = comfy.ops.Conv2d(out_channels,
|
||||
self.conv2 = ops.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
|
||||
self.conv_shortcut = ops.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
|
||||
self.nin_shortcut = ops.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
@@ -245,22 +246,22 @@ class AttnBlock(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = comfy.ops.Conv2d(in_channels,
|
||||
self.q = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = comfy.ops.Conv2d(in_channels,
|
||||
self.k = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = comfy.ops.Conv2d(in_channels,
|
||||
self.v = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = comfy.ops.Conv2d(in_channels,
|
||||
self.proj_out = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
@@ -312,14 +313,14 @@ class Model(nn.Module):
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
comfy.ops.Linear(self.ch,
|
||||
ops.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
comfy.ops.Linear(self.temb_ch,
|
||||
ops.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = comfy.ops.Conv2d(in_channels,
|
||||
self.conv_in = ops.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -388,7 +389,7 @@ class Model(nn.Module):
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = comfy.ops.Conv2d(block_in,
|
||||
self.conv_out = ops.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -461,7 +462,7 @@ class Encoder(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = comfy.ops.Conv2d(in_channels,
|
||||
self.conv_in = ops.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -506,7 +507,7 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = comfy.ops.Conv2d(block_in,
|
||||
self.conv_out = ops.Conv2d(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
@@ -541,7 +542,7 @@ class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
conv_out_op=comfy.ops.Conv2d,
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
**ignorekwargs):
|
||||
@@ -565,7 +566,7 @@ class Decoder(nn.Module):
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = comfy.ops.Conv2d(z_channels,
|
||||
self.conv_in = ops.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
|
Reference in New Issue
Block a user