Diffstat (limited to 'training')
 -rw-r--r--  training/util.py  |  60
 1 file changed, 56 insertions(+), 4 deletions(-)
diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
             del image_grid
 
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of model weights.
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy the current averaged parameters into the given collection of
+        parameters, typically the live model parameters before running
+        validation or saving a checkpoint.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to
+                be updated with the stored moving averages.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move the internal buffers of the `EMAModel` to `device`.
+        Args:
+            device: like the `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
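
For context, a minimal usage sketch of the new EMAModel class follows. It is not part of this commit: the model, optimizer, and training loop are hypothetical placeholders; only the EMAModel calls mirror the API added in this diff.

# Hypothetical training loop showing where the EMA hooks in. The
# Linear model and dummy loss are stand-ins, not code from this repo.
import torch
from training.util import EMAModel

model = torch.nn.Linear(4, 4)  # placeholder for the real trainable model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
ema = EMAModel(model.parameters(), decay=0.9999)

for step in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()
    # Update the shadow weights after every optimizer step. The effective
    # decay is min(0.9999, (1 + t) / (10 + t)), so the average warms up
    # quickly (about 0.18 at t=1, about 0.92 at t=100) before saturating
    # at the configured 0.9999.
    ema.step(model.parameters())

# For evaluation or checkpointing, overwrite the live weights with the
# averaged ones; back them up first if training continues afterwards.
backup = [p.clone() for p in model.parameters()]
ema.copy_to(model.parameters())
# ... run validation / save checkpoint ...
for p, b in zip(model.parameters(), backup):
    p.data.copy_(b.data)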
