From 3575d041f1507811b577fd2c653171fb51c0a386 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 20 Jan 2023 14:26:17 +0100
Subject: Restored LR finder

---
 training/util.py | 146 ++-----------------------------------------------------
 1 file changed, 3 insertions(+), 143 deletions(-)

(limited to 'training/util.py')

diff --git a/training/util.py b/training/util.py
index 237626f..c8524de 100644
--- a/training/util.py
+++ b/training/util.py
@@ -6,6 +6,8 @@ from contextlib import contextmanager
 
 import torch
 
+from diffusers.training_utils import EMAModel as EMAModel_
+
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -30,149 +32,7 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(
-        self,
-        parameters: Iterable[torch.nn.Parameter],
-        update_after_step: int = 0,
-        inv_gamma: float = 1.0,
-        power: float = 2 / 3,
-        min_value: float = 0.0,
-        max_value: float = 0.9999,
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-            at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.decay = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step: int):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        self.decay = self.get_decay(self.optimization_step)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.mul_(self.decay)
-                s_param.add_(param.data, alpha=1 - self.decay)
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
+class EMAModel(EMAModel_):
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
         parameters = list(parameters)
-- 
cgit v1.2.3-54-g00ecf
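
Note on the change: the hand-rolled EMAModel is dropped in favour of a thin subclass of diffusers.training_utils.EMAModel, keeping only the apply_temporary context manager. The diff truncates the body of apply_temporary, so the following is only a sketch of how such a context manager is conventionally built on top of diffusers' EMAModel: copy_to() is a real method of that class, but the backup/restore logic and the ema_model / ema_unet / run_validation names are illustrative assumptions, not the repository's actual code.

# Sketch only; assumed behaviour, since the real apply_temporary body is not shown above.
from contextlib import contextmanager
from typing import Iterable

import torch


@contextmanager
def apply_temporary(ema_model, parameters: Iterable[torch.nn.Parameter]):
    """Temporarily load the EMA weights into `parameters`, restoring the originals on exit."""
    parameters = list(parameters)
    # Back up the live weights before overwriting them with the averaged ones.
    backup = [p.detach().clone() for p in parameters]
    try:
        ema_model.copy_to(parameters)  # copy_to() is provided by diffusers' EMAModel
        yield
    finally:
        # Restore the live (non-averaged) weights.
        for param, saved in zip(parameters, backup):
            param.data.copy_(saved)

# Typical use during validation:
#   with apply_temporary(ema_unet, unet.parameters()):
#       run_validation(unet)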