diff options
author | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:18 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:18 +0100 |
commit | f4f90c487cbc247952689e906519d8e2eb21da99 (patch) | |
tree | fc308cdcf02c36437e8017fab5961294f86930fe /training | |
parent | Log EMA decay (diff) | |
download | textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.gz textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.bz2 textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.zip |
Add contextmanager to EMAModel to apply weights temporarily
Diffstat (limited to 'training')
-rw-r--r-- | training/util.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/training/util.py b/training/util.py index 93b6248..6f1e85a 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -2,6 +2,7 @@ from pathlib import Path | |||
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable | 4 | from typing import Iterable |
5 | from contextlib import contextmanager | ||
5 | 6 | ||
6 | import torch | 7 | import torch |
7 | from PIL import Image | 8 | from PIL import Image |
@@ -259,3 +260,14 @@ class EMAModel: | |||
259 | raise ValueError("collected_params must all be Tensors") | 260 | raise ValueError("collected_params must all be Tensors") |
260 | if len(self.collected_params) != len(self.shadow_params): | 261 | if len(self.collected_params) != len(self.shadow_params): |
261 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") |
263 | |||
264 | @contextmanager | ||
265 | def apply_temporary(self, parameters): | ||
266 | try: | ||
267 | parameters = list(parameters) | ||
268 | original_params = [p.clone() for p in parameters] | ||
269 | self.copy_to(parameters) | ||
270 | yield | ||
271 | finally: | ||
272 | for s_param, param in zip(original_params, parameters): | ||
273 | param.data.copy_(s_param.data) | ||