From 5c115a212e40ff177c734351601f9babe29419ce Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 22:05:25 +0100 Subject: Added EMA to TI --- training/util.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 5 deletions(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index 43a55e1..93b6248 100644 --- a/training/util.py +++ b/training/util.py @@ -1,5 +1,6 @@ from pathlib import Path import json +import copy from typing import Iterable import torch @@ -116,18 +117,58 @@ class CheckpointerBase: del generator +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ Exponential Moving Average of models weights """ - def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + update_after_step=0, + inv_gamma=1.0, + power=2 / 3, + min_value=0.0, + max_value=0.9999, + ): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_value (float): The minimum EMA decay rate. Default: 0. + """ parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] - self.decay = decay + self.collected_params = None + + self.update_after_step = update_after_step + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + + self.decay = 0.0 self.optimization_step = 0 + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + @torch.no_grad() def step(self, parameters): parameters = list(parameters) @@ -135,12 +176,12 @@ class EMAModel: self.optimization_step += 1 # Compute the decay factor for the exponential moving average. - value = (1 + self.optimization_step) / (10 + self.optimization_step) - one_minus_decay = 1 - min(self.decay, value) + self.decay = self.get_decay(self.optimization_step) for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) + s_param.mul_(self.decay) + s_param.add_(param.data, alpha=1 - self.decay) else: s_param.copy_(param) @@ -169,3 +210,52 @@ class EMAModel: p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. + This method is used by accelerate during checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "optimization_step": self.optimization_step, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Loads the ExponentialMovingAverage state. + This method is used by accelerate during checkpointing to save the ema state dict. + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.optimization_step = state_dict["optimization_step"] + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") -- cgit v1.2.3-54-g00ecf