Diffstat (limited to 'training')
 training/util.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 4 deletions(-)
diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
                 del image_grid
 
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of model weights
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy the current averaged parameters into the given collection of
+        parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to
+                be updated with the stored moving averages. Must be the same
+                parameters, in the same order, that this `EMAModel` tracks.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move the internal buffers of the EMAModel to `device`.
+        Args:
+            device: like the `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
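
For context, a minimal usage sketch of the new EMAModel class, assuming the usual pattern: update the shadow weights after each optimizer step, then copy them into the model for sampling or evaluation. The model, optimizer, and training loop below are hypothetical stand-ins, not part of this commit; only EMAModel itself comes from training/util.py. Note the decay warm-up in step(): min(decay, (1 + t) / (10 + t)) starts near 0.18 at the first step and approaches the configured 0.9999, so the shadow weights can track the fast-moving model early in training.

import torch

from training.util import EMAModel

# Hypothetical stand-ins for the real model and optimizer.
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMAModel(model.parameters(), decay=0.9999)

for _ in range(100):  # hypothetical training loop
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())  # update shadow weights after each step

# Before sampling/evaluation, overwrite the live weights with the averaged
# ones. copy_to() mutates the parameters in place, so keep a separate copy
# if training should resume from the non-averaged weights.
ema.copy_to(model.parameters())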