From 672a59abeaa60dc5ef78a33bd9b58e391b922016 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 11:14:24 +0100 Subject: Use context manager for EMA, on_train/eval hooks --- training/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index 6f1e85a..bed7111 100644 --- a/training/util.py +++ b/training/util.py @@ -262,7 +262,7 @@ class EMAModel: raise ValueError("collected_params and shadow_params must have the same length") @contextmanager - def apply_temporary(self, parameters): + def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): try: parameters = list(parameters) original_params = [p.clone() for p in parameters] -- cgit v1.2.3-54-g00ecf