diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/util.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/training/util.py b/training/util.py index 93b6248..6f1e85a 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -2,6 +2,7 @@ from pathlib import Path | |||
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable | 4 | from typing import Iterable |
5 | from contextlib import contextmanager | ||
5 | 6 | ||
6 | import torch | 7 | import torch |
7 | from PIL import Image | 8 | from PIL import Image |
@@ -259,3 +260,14 @@ class EMAModel: | |||
259 | raise ValueError("collected_params must all be Tensors") | 260 | raise ValueError("collected_params must all be Tensors") |
260 | if len(self.collected_params) != len(self.shadow_params): | 261 | if len(self.collected_params) != len(self.shadow_params): |
261 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") |
263 | |||
264 | @contextmanager | ||
265 | def apply_temporary(self, parameters): | ||
266 | try: | ||
267 | parameters = list(parameters) | ||
268 | original_params = [p.clone() for p in parameters] | ||
269 | self.copy_to(parameters) | ||
270 | yield | ||
271 | finally: | ||
272 | for s_param, param in zip(original_params, parameters): | ||
273 | param.data.copy_(s_param.data) | ||