Diffstat (limited to 'training')
 training/util.py | 100
 1 file changed, 95 insertions(+), 5 deletions(-)
diff --git a/training/util.py b/training/util.py
index 43a55e1..93b6248 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import copy
 from typing import Iterable
 
 import torch
@@ -116,18 +117,58 @@ class CheckpointerBase:
         del generator
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        update_after_step=0,
+        inv_gamma=1.0,
+        power=2 / 3,
+        min_value=0.0,
+        max_value=0.9999,
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+        at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+        """
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]
 
-        self.decay = decay
+        self.collected_params = None
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
         self.optimization_step = 0
 
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
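The new get_decay method replaces the old fixed decay with a warmup schedule that ramps from 0 toward max_value as 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value]. Below is a minimal standalone sketch of that schedule under the default inv_gamma=1.0 and power=2/3; the helper name warmup_decay is purely illustrative and not part of training/util.py.

# Standalone sketch of the decay schedule used by EMAModel.get_decay above,
# assuming the default inv_gamma=1.0 and power=2/3 (warmup_decay is a
# hypothetical helper name, used only for this illustration).
def warmup_decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

# Roughly: 0.0 at step 0, ~0.95 at step 100, ~0.998 at 10k steps,
# ~0.999 around 31.6k steps, and clamped to 0.9999 from ~1M steps onward.
for step in (0, 100, 10_000, 31_623, 1_000_000):
    print(step, warmup_decay(step))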
@@ -135,12 +176,12 @@ class EMAModel:
         self.optimization_step += 1
 
         # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
+        self.decay = self.get_decay(self.optimization_step)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                s_param.mul_(self.decay)
+                s_param.mul_(self.decay)
+                s_param.add_(param.data, alpha=1 - self.decay)
             else:
                 s_param.copy_(param)
 
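With the schedule in place, step() now applies the standard EMA recurrence shadow = decay * shadow + (1 - decay) * param in place. A hedged usage sketch follows; it assumes EMAModel is importable from training.util and that step() is called once per optimizer update, and the model, optimizer, and dummy loss below are placeholders, not code from this repository.

# Hypothetical training-loop usage of EMAModel (placeholder model and loss).
import torch
from training.util import EMAModel

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMAModel(model.parameters(), update_after_step=100, power=2 / 3)

for _ in range(1_000):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model.parameters())  # shadow = decay * shadow + (1 - decay) * param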
@@ -169,3 +210,52 @@ class EMAModel:
             p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
             for p in self.shadow_params
         ]
+
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
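The state_dict()/load_state_dict() pair added above exists so accelerate can checkpoint and restore the EMA state. A minimal roundtrip sketch, using plain torch.save/torch.load in place of the accelerate hooks (the file name ema.pt and the placeholder model are illustrative only):

# Minimal save/restore roundtrip for the new state_dict()/load_state_dict() methods;
# torch.save/torch.load stand in for accelerate's checkpointing here.
import torch
from training.util import EMAModel

model = torch.nn.Linear(4, 4)
ema = EMAModel(model.parameters())
ema.step(model.parameters())

torch.save(ema.state_dict(), "ema.pt")

restored = EMAModel(model.parameters())
restored.load_state_dict(torch.load("ema.pt"))
assert restored.optimization_step == ema.optimization_step
assert all(torch.equal(a, b) for a, b in zip(restored.shadow_params, ema.shadow_params))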
