author     Volpeon <git@volpeon.ink>    2023-01-05 22:05:25 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-05 22:05:25 +0100
commit     5c115a212e40ff177c734351601f9babe29419ce (patch)
tree       a66c8c67d2811e126b52ac4d4cd30a1c3ea2c2b9 /training
parent     Fix LR finder (diff)
Added EMA to TI
Diffstat (limited to 'training')
-rw-r--r--  training/util.py  100
1 file changed, 95 insertions, 5 deletions
diff --git a/training/util.py b/training/util.py
index 43a55e1..93b6248 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import copy
 from typing import Iterable
 
 import torch
@@ -116,18 +117,58 @@ class CheckpointerBase:
         del generator
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        update_after_step=0,
+        inv_gamma=1.0,
+        power=2 / 3,
+        min_value=0.0,
+        max_value=0.9999,
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+            at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+        """
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]
 
-        self.decay = decay
+        self.collected_params = None
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
         self.optimization_step = 0
 
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
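The warmup schedule added in get_decay follows the formula quoted in the docstring. As a quick standalone check (a sketch that reimplements the same formula outside the class, not part of this commit), the decay factor with the defaults inv_gamma=1 and power=2/3 ramps as described:

```python
# Standalone re-implementation of the warmup schedule from get_decay above,
# used only to verify the numbers quoted in the docstring.
def warmup_decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

for step in (1_000, 31_623, 1_000_000):
    print(step, round(warmup_decay(step), 4))
# 1_000     -> ~0.990
# 31_623    -> ~0.999   (matches "decay factor 0.999 at 31.6K steps")
# 1_000_000 -> 0.9999   (clamped at max_value)
```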
@@ -135,12 +176,12 @@ class EMAModel:
         self.optimization_step += 1
 
         # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
+        self.decay = self.get_decay(self.optimization_step)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
+                s_param.mul_(self.decay)
+                s_param.add_(param.data, alpha=1 - self.decay)
             else:
                 s_param.copy_(param)
 
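The rewritten update is the same exponential moving average as before, shadow = decay * shadow + (1 - decay) * param, just expressed as an in-place mul_/add_ pair and driven by the warmed-up decay instead of the old (1 + t) / (10 + t) rule. A minimal usage sketch follows; the real call site lives in the training script, not in this file, and the module and optimizer below are placeholders:

```python
import torch

from training.util import EMAModel  # assumes the repository root is on sys.path

model = torch.nn.Linear(4, 4)  # placeholder for the embeddings being trained
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ema = EMAModel(model.parameters(), inv_gamma=1.0, power=2 / 3, max_value=0.9999)

for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())  # fold the new weights into the shadow copy
```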
@@ -169,3 +210,52 @@ class EMAModel:
             p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
             for p in self.shadow_params
         ]
+
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to load the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
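The state_dict()/load_state_dict() pair gives the EMA object the protocol accelerate expects from custom checkpointable objects, so the shadow weights can be saved and restored together with the rest of the training state. Below is a sketch of how that could be wired up; the Accelerator setup and the placeholder model are assumptions, not part of this commit:

```python
import torch
from accelerate import Accelerator

from training.util import EMAModel  # assumes the repository root is on sys.path

model = torch.nn.Linear(4, 4)  # placeholder for the trained parameters
accelerator = Accelerator()
ema = EMAModel(model.parameters())

# Any object exposing state_dict()/load_state_dict() can be registered, so
# save_state()/load_state() round-trip the EMA state with models and optimizers.
accelerator.register_for_checkpointing(ema)

accelerator.save_state("checkpoints/step-1000")
accelerator.load_state("checkpoints/step-1000")
```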