diff options
author | Volpeon <git@volpeon.ink> | 2023-01-06 11:14:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-06 11:14:24 +0100 |
commit | 672a59abeaa60dc5ef78a33bd9b58e391b922016 (patch) | |
tree | 1afb3a943af3fa7c935d133cf2768a33f11f8235 /training/util.py | |
parent | Package update (diff) | |
download | textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.tar.gz textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.tar.bz2 textual-inversion-diff-672a59abeaa60dc5ef78a33bd9b58e391b922016.zip |
Use context manager for EMA, on_train/eval hooks
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py index 6f1e85a..bed7111 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -262,7 +262,7 @@ class EMAModel: | |||
262 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") |
263 | 263 | ||
264 | @contextmanager | 264 | @contextmanager |
265 | def apply_temporary(self, parameters): | 265 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): |
266 | try: | 266 | try: |
267 | parameters = list(parameters) | 267 | parameters = list(parameters) |
268 | original_params = [p.clone() for p in parameters] | 268 | original_params = [p.clone() for p in parameters] |