1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-02 15:04:50 +08:00

Refactor comfy.ops

comfy.ops -> comfy.ops.disable_weight_init

This should make it more clear what they actually do.

Some unused code has also been removed.
This commit is contained in:
comfyanonymous
2023-12-11 23:27:13 -05:00
parent b0aab1e4ea
commit 77755ab8db
10 changed files with 94 additions and 170 deletions

View File

@@ -12,13 +12,13 @@ from .util import (
checkpoint,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
ops = comfy.ops.disable_weight_init
class TimestepBlock(nn.Module):
"""
@@ -70,7 +70,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -106,7 +106,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -159,7 +159,7 @@ class ResBlock(TimestepBlock):
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
super().__init__()
self.channels = channels
@@ -284,7 +284,7 @@ class VideoResBlock(ResBlock):
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
super().__init__(
channels,
@@ -434,7 +434,7 @@ class UNetModel(nn.Module):
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -581,7 +581,7 @@ class UNetModel(nn.Module):
up=False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
if self.use_temporal_resblocks:
return VideoResBlock(