summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/util.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/training/util.py b/training/util.py
index 93b6248..6f1e85a 100644
--- a/training/util.py
+++ b/training/util.py
@@ -2,6 +2,7 @@ from pathlib import Path
2import json 2import json
3import copy 3import copy
4from typing import Iterable 4from typing import Iterable
5from contextlib import contextmanager
5 6
6import torch 7import torch
7from PIL import Image 8from PIL import Image
@@ -259,3 +260,14 @@ class EMAModel:
259 raise ValueError("collected_params must all be Tensors") 260 raise ValueError("collected_params must all be Tensors")
260 if len(self.collected_params) != len(self.shadow_params): 261 if len(self.collected_params) != len(self.shadow_params):
261 raise ValueError("collected_params and shadow_params must have the same length") 262 raise ValueError("collected_params and shadow_params must have the same length")
263
264 @contextmanager
265 def apply_temporary(self, parameters):
266 try:
267 parameters = list(parameters)
268 original_params = [p.clone() for p in parameters]
269 self.copy_to(parameters)
270 yield
271 finally:
272 for s_param, param in zip(original_params, parameters):
273 param.data.copy_(s_param.data)