Diffstat (limited to 'training/util.py')
-rw-r--r--  training/util.py  146
1 file changed, 3 insertions(+), 143 deletions(-)
diff --git a/training/util.py b/training/util.py
index 237626f..c8524de 100644
--- a/training/util.py
+++ b/training/util.py
@@ -6,6 +6,8 @@ from contextlib import contextmanager
 
 import torch
 
+from diffusers.training_utils import EMAModel as EMAModel_
+
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -30,149 +32,7 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(
-        self,
-        parameters: Iterable[torch.nn.Parameter],
-        update_after_step: int = 0,
-        inv_gamma: float = 1.0,
-        power: float = 2 / 3,
-        min_value: float = 0.0,
-        max_value: float = 0.9999,
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-        at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.decay = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step: int):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        self.decay = self.get_decay(self.optimization_step)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.mul_(self.decay)
-                s_param.add_(param.data, alpha=1 - self.decay)
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
+class EMAModel(EMAModel_):
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        parameters = list(parameters)
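
Note on the change: the rewrite keeps only the apply_temporary context manager on top of diffusers' EMAModel, which already supplies step/copy_to/state_dict handling and (depending on the diffusers version) the same EMA warmup schedule the deleted class computed by hand, decay = 1 - (1 + step / inv_gamma) ** -power, clamped between min_value and max_value. Because the hunk above truncates the body of apply_temporary, the following is only a minimal sketch of how such a subclass could look and be used; the restore logic and the usage comments are assumptions for illustration, not the repository's actual code, and the base-class constructor arguments differ between diffusers versions.

from contextlib import contextmanager
from typing import Iterable

import torch

from diffusers.training_utils import EMAModel as EMAModel_


class EMAModel(EMAModel_):
    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        # Assumed body: temporarily load the EMA weights into the given
        # parameters, then restore the original training weights on exit.
        parameters = list(parameters)
        original_params = [p.clone() for p in parameters]  # keep the live weights
        self.copy_to(parameters)  # overwrite with the EMA weights
        try:
            yield
        finally:
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param.data)  # restore the live weights


# Hypothetical usage (names like `model` and `validate` are placeholders):
# ema = EMAModel(model.parameters())  # decay/warmup kwargs depend on the diffusers version
# ema.step(model.parameters())        # call after each optimizer step
# with ema.apply_temporary(model.parameters()):
#     validate(model)                 # model temporarily holds the EMA weights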