From a2c240c8c55dfe930657f66372975d6f26feb168 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 10:02:30 +0100 Subject: TI: Prepare UNet with Accelerate as well --- training/util.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index 1008021..781cf04 100644 --- a/training/util.py +++ b/training/util.py @@ -134,11 +134,11 @@ class EMAModel: def __init__( self, parameters: Iterable[torch.nn.Parameter], - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, + update_after_step: int = 0, + inv_gamma: float = 1.0, + power: float = 2 / 3, + min_value: float = 0.0, + max_value: float = 0.9999, ): """ @crowsonkb's notes on EMA Warmup: @@ -165,7 +165,7 @@ class EMAModel: self.decay = 0.0 self.optimization_step = 0 - def get_decay(self, optimization_step): + def get_decay(self, optimization_step: int): """ Compute the decay factor for the exponential moving average. """ @@ -276,5 +276,5 @@ class EMAModel: self.copy_to(parameters) yield finally: - for s_param, param in zip(original_params, parameters): - param.data.copy_(s_param.data) + for o_param, param in zip(original_params, parameters): + param.data.copy_(o_param.data) -- cgit v1.2.3-54-g00ecf