Diffstat (limited to 'training')
 -rw-r--r--  training/util.py | 12 ++++++++++++
 1 file changed, 12 insertions(+), 0 deletions(-)
diff --git a/training/util.py b/training/util.py
index 93b6248..6f1e85a 100644
--- a/training/util.py
+++ b/training/util.py
@@ -2,6 +2,7 @@ from pathlib import Path
 import json
 import copy
 from typing import Iterable
+from contextlib import contextmanager
 
 import torch
 from PIL import Image
@@ -259,3 +260,14 @@ class EMAModel:
             raise ValueError("collected_params must all be Tensors")
         if len(self.collected_params) != len(self.shadow_params):
             raise ValueError("collected_params and shadow_params must have the same length")
+
+    @contextmanager
+    def apply_temporary(self, parameters):
+        try:
+            parameters = list(parameters)
+            original_params = [p.clone() for p in parameters]
+            self.copy_to(parameters)
+            yield
+        finally:
+            for s_param, param in zip(original_params, parameters):
+                param.data.copy_(s_param.data)
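
The new apply_temporary context manager swaps a model's live parameters for the EMA shadow weights for the duration of a with-block (e.g. for validation) and restores the originals on exit, even if the block raises. A minimal usage sketch, assuming an EMAModel instance `ema` that has been tracking `model`'s parameters; the names `model`, `ema`, and `run_validation` are illustrative and not part of this commit:

    # Hypothetical usage; `model`, `ema`, `run_validation` are stand-in names.
    with ema.apply_temporary(model.parameters()):
        # Inside the block, model.parameters() hold the EMA shadow weights.
        run_validation(model)
    # On exit the original training weights are copied back; the
    # try/finally in apply_temporary guarantees this even on exceptions.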
