summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-06 11:14:24 +0100
committerVolpeon <git@volpeon.ink>2023-01-06 11:14:24 +0100
commit672a59abeaa60dc5ef78a33bd9b58e391b922016 (patch)
tree1afb3a943af3fa7c935d133cf2768a33f11f8235 /training/util.py
parentPackage update (diff)
downloadtextual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.tar.gz
textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.tar.bz2
textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.zip
Use context manager for EMA, on_train/eval hooks
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py
index 6f1e85a..bed7111 100644
--- a/training/util.py
+++ b/training/util.py
@@ -262,7 +262,7 @@ class EMAModel:
262 raise ValueError("collected_params and shadow_params must have the same length") 262 raise ValueError("collected_params and shadow_params must have the same length")
263 263
264 @contextmanager 264 @contextmanager
265 def apply_temporary(self, parameters): 265 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
266 try: 266 try:
267 parameters = list(parameters) 267 parameters = list(parameters)
268 original_params = [p.clone() for p in parameters] 268 original_params = [p.clone() for p in parameters]