path: root/training/util.py
author     Volpeon <git@volpeon.ink>  2023-01-05 22:05:25 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-05 22:05:25 +0100
commit     5c115a212e40ff177c734351601f9babe29419ce (patch)
tree       a66c8c67d2811e126b52ac4d4cd30a1c3ea2c2b9 /training/util.py
parent     Fix LR finder (diff)
Added EMA to TI
Diffstat (limited to 'training/util.py')
-rw-r--r--  training/util.py | 100
1 files changed, 95 insertions, 5 deletions
diff --git a/training/util.py b/training/util.py
index 43a55e1..93b6248 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import copy
 from typing import Iterable
 
 import torch
@@ -116,18 +117,58 @@ class CheckpointerBase:
         del generator
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        update_after_step=0,
+        inv_gamma=1.0,
+        power=2 / 3,
+        min_value=0.0,
+        max_value=0.9999,
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+            at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+        """
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]
 
-        self.decay = decay
+        self.collected_params = None
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
         self.optimization_step = 0
 
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
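The get_decay() method added above implements @crowsonkb's EMA warmup: decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value]. A small standalone sketch (illustrative only, not part of this commit) shows how the default inv_gamma=1, power=2/3 schedule ramps up:

# Standalone sketch of the warmup schedule used by get_decay() above.
# Not part of the commit; the numbers in the comments are approximate.
def ema_decay(step, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999):
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

print(ema_decay(100))        # ~0.954  (still warming up)
print(ema_decay(31_600))     # ~0.999  (matches the docstring's 31.6K-step figure)
print(ema_decay(1_000_000))  # 0.9999  (clamped by max_value)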
@@ -135,12 +176,12 @@ class EMAModel:
         self.optimization_step += 1
 
         # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
+        self.decay = self.get_decay(self.optimization_step)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
+                s_param.mul_(self.decay)
+                s_param.add_(param.data, alpha=1 - self.decay)
             else:
                 s_param.copy_(param)
 
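The rewritten step() applies the usual EMA recurrence shadow = decay * shadow + (1 - decay) * param in place; it is algebraically the same update as the old sub_() form, just split into mul_/add_. A quick equivalence check (illustrative, not from the commit):

import torch

decay = 0.99
shadow_old = torch.randn(4)
param = torch.randn(4)
shadow_new = shadow_old.clone()

shadow_old.sub_((1 - decay) * (shadow_old - param))  # old single-op update
shadow_new.mul_(decay)
shadow_new.add_(param, alpha=1 - decay)              # new two-op update

assert torch.allclose(shadow_old, shadow_new)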
@@ -169,3 +210,52 @@ class EMAModel:
             p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
             for p in self.shadow_params
         ]
+
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
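Taken together, the new constructor arguments, step(), and the state_dict()/load_state_dict() pair let the EMA weights track the trained parameters and survive checkpointing. A minimal, self-contained usage sketch follows; the toy embedding layer, training loop, and file name are stand-ins, not code from this repository:

import torch
from training.util import EMAModel

# Tiny stand-in for the textual-inversion embeddings being trained.
emb = torch.nn.Embedding(8, 4)
optimizer = torch.optim.SGD(emb.parameters(), lr=1e-2)
ema = EMAModel(emb.parameters(), inv_gamma=1.0, power=2 / 3, max_value=0.9999)

for _ in range(10):
    loss = emb(torch.randint(0, 8, (16,))).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(emb.parameters())   # update the shadow params after each optimizer step

# The EMA state round-trips through a plain dict, so it can be checkpointed:
torch.save(ema.state_dict(), "ema.pt")
ema.load_state_dict(torch.load("ema.pt"))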