1
mirror of https://github.com/comfyanonymous/ComfyUI.git synced 2025-08-04 07:52:46 +08:00

Lint and fix undefined names (1/N) (#6028)

This commit is contained in:
Chenlei Hu
2024-12-12 15:55:26 -08:00
committed by GitHub
parent 60749f345d
commit 2cddbf0821
5 changed files with 15 additions and 9 deletions

View File

@@ -1,3 +1,5 @@
import logging
import math
import torch
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
@@ -52,7 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any:
raise NotImplementedError()
@@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f"{context}: Switched to EMA weights")
logging.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f"{context}: Restored training weights")
logging.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
self.regularization = instantiate_from_config(
regularizer_config
)