Diffstat (limited to 'training')
-rw-r--r--  training/util.py  60
1 file changed, 56 insertions(+), 4 deletions(-)
diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
             del image_grid
 
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of model weights.
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy the current averaged parameters into the given collection of parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages, typically the same
+                parameters that this `EMAModel` was initialized with.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the EMAModel to `device`.
+        Args:
+            device: like the `device` argument to `torch.Tensor.to`
+            dtype: like the `dtype` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
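
Usage note (not part of the commit): EMAModel keeps a detached "shadow" copy of each parameter and nudges it toward the live weights after every optimizer step. The decay is warmed up via min(decay, (1 + step) / (10 + step)); at step 1 the effective decay is only 2/11 ≈ 0.18, so the shadow weights initially track the live weights closely and only approach the configured 0.9999 later in training. Below is a minimal sketch of how the class might be wired into a training loop; the model, optimizer, and toy data are hypothetical stand-ins, and only EMAModel itself comes from training/util.py (assuming the repository root is on the import path).

import torch

from training.util import EMAModel  # the class added in this commit

# Hypothetical stand-ins for the real network and data.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
ema = EMAModel(model.parameters(), decay=0.9999)

batches = [torch.randn(8, 16) for _ in range(10)]  # toy data

for batch in batches:
    loss = model(batch).pow(2).mean()  # toy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())  # update the shadow weights after each optimizer step

# Copy the averaged weights into the given parameters for sampling or checkpointing.
ema.copy_to(model.parameters())

Note that copy_to() overwrites the given parameters in place, so in a real run it is typically applied to a separate copy of the model used for evaluation or export rather than to the model still being trained.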