summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
180 180
181 @contextmanager 181 @contextmanager
182 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): 182 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
183 parameters = list(parameters)
184 original_params = [p.clone() for p in parameters]
185 self.copy_to(parameters)
186
183 try: 187 try:
184 parameters = list(parameters)
185 original_params = [p.clone() for p in parameters]
186 self.copy_to(parameters)
187 yield 188 yield
188 finally: 189 finally:
189 for o_param, param in zip(original_params, parameters): 190 for o_param, param in zip(original_params, parameters):
190 param.data.copy_(o_param.data) 191 param.data.copy_(o_param.data)
192 del original_params