diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:18 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:18 +0100 | 
| commit | f4f90c487cbc247952689e906519d8e2eb21da99 (patch) | |
| tree | fc308cdcf02c36437e8017fab5961294f86930fe /training/util.py | |
| parent | Log EMA decay (diff) | |
| download | textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.gz textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.bz2 textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.zip | |
Add contextmanager to EMAModel to apply weights temporarily
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 12 | 
1 files changed, 12 insertions, 0 deletions
| diff --git a/training/util.py b/training/util.py index 93b6248..6f1e85a 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -2,6 +2,7 @@ from pathlib import Path | |||
| 2 | import json | 2 | import json | 
| 3 | import copy | 3 | import copy | 
| 4 | from typing import Iterable | 4 | from typing import Iterable | 
| 5 | from contextlib import contextmanager | ||
| 5 | 6 | ||
| 6 | import torch | 7 | import torch | 
| 7 | from PIL import Image | 8 | from PIL import Image | 
| @@ -259,3 +260,14 @@ class EMAModel: | |||
| 259 | raise ValueError("collected_params must all be Tensors") | 260 | raise ValueError("collected_params must all be Tensors") | 
| 260 | if len(self.collected_params) != len(self.shadow_params): | 261 | if len(self.collected_params) != len(self.shadow_params): | 
| 261 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") | 
| 263 | |||
| 264 | @contextmanager | ||
| 265 | def apply_temporary(self, parameters): | ||
| 266 | try: | ||
| 267 | parameters = list(parameters) | ||
| 268 | original_params = [p.clone() for p in parameters] | ||
| 269 | self.copy_to(parameters) | ||
| 270 | yield | ||
| 271 | finally: | ||
| 272 | for s_param, param in zip(original_params, parameters): | ||
| 273 | param.data.copy_(s_param.data) | ||
