From f4f90c487cbc247952689e906519d8e2eb21da99 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 09:07:18 +0100 Subject: Add contextmanager to EMAModel to apply weights temporarily --- training/util.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index 93b6248..6f1e85a 100644 --- a/training/util.py +++ b/training/util.py @@ -2,6 +2,7 @@ from pathlib import Path import json import copy from typing import Iterable +from contextlib import contextmanager import torch from PIL import Image @@ -259,3 +260,14 @@ class EMAModel: raise ValueError("collected_params must all be Tensors") if len(self.collected_params) != len(self.shadow_params): raise ValueError("collected_params and shadow_params must have the same length") + + @contextmanager + def apply_temporary(self, parameters): + try: + parameters = list(parameters) + original_params = [p.clone() for p in parameters] + self.copy_to(parameters) + yield + finally: + for s_param, param in zip(original_params, parameters): + param.data.copy_(s_param.data) -- cgit v1.2.3-54-g00ecf