diff options
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/training/util.py b/training/util.py index f46cc61..557b196 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -180,11 +180,13 @@ class EMAModel: | |||
180 | 180 | ||
181 | @contextmanager | 181 | @contextmanager |
182 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): | 182 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): |
183 | parameters = list(parameters) | ||
184 | original_params = [p.clone() for p in parameters] | ||
185 | self.copy_to(parameters) | ||
186 | |||
183 | try: | 187 | try: |
184 | parameters = list(parameters) | ||
185 | original_params = [p.clone() for p in parameters] | ||
186 | self.copy_to(parameters) | ||
187 | yield | 188 | yield |
188 | finally: | 189 | finally: |
189 | for o_param, param in zip(original_params, parameters): | 190 | for o_param, param in zip(original_params, parameters): |
190 | param.data.copy_(o_param.data) | 191 | param.data.copy_(o_param.data) |
192 | del original_params | ||