summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-06 09:07:18 +0100
committerVolpeon <git@volpeon.ink>2023-01-06 09:07:18 +0100
commitf4f90c487cbc247952689e906519d8e2eb21da99 (patch)
treefc308cdcf02c36437e8017fab5961294f86930fe /training/util.py
parentLog EMA decay (diff)
downloadtextual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.gz
textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.bz2
textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.zip
Add contextmanager to EMAModel to apply weights temporarily
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/training/util.py b/training/util.py
index 93b6248..6f1e85a 100644
--- a/training/util.py
+++ b/training/util.py
@@ -2,6 +2,7 @@ from pathlib import Path
2import json 2import json
3import copy 3import copy
4from typing import Iterable 4from typing import Iterable
5from contextlib import contextmanager
5 6
6import torch 7import torch
7from PIL import Image 8from PIL import Image
@@ -259,3 +260,14 @@ class EMAModel:
259 raise ValueError("collected_params must all be Tensors") 260 raise ValueError("collected_params must all be Tensors")
260 if len(self.collected_params) != len(self.shadow_params): 261 if len(self.collected_params) != len(self.shadow_params):
261 raise ValueError("collected_params and shadow_params must have the same length") 262 raise ValueError("collected_params and shadow_params must have the same length")
263
264 @contextmanager
265 def apply_temporary(self, parameters):
266 try:
267 parameters = list(parameters)
268 original_params = [p.clone() for p in parameters]
269 self.copy_to(parameters)
270 yield
271 finally:
272 for s_param, param in zip(original_params, parameters):
273 param.data.copy_(s_param.data)