summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-15 21:06:16 +0100
committerVolpeon <git@volpeon.ink>2023-01-15 21:06:16 +0100
commit632ce00b54ffeacfc18f44f10827f167ab3ac37c (patch)
treeecf58df2b176d3c7d1583136bf453ed24de8d7f3 /training/util.py
parentFixed Conda env (diff)
downloadtextual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.gz
textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.bz2
textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.zip
Restored functional trainer
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
180 180
181 @contextmanager 181 @contextmanager
182 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): 182 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
183 parameters = list(parameters)
184 original_params = [p.clone() for p in parameters]
185 self.copy_to(parameters)
186
183 try: 187 try:
184 parameters = list(parameters)
185 original_params = [p.clone() for p in parameters]
186 self.copy_to(parameters)
187 yield 188 yield
188 finally: 189 finally:
189 for o_param, param in zip(original_params, parameters): 190 for o_param, param in zip(original_params, parameters):
190 param.data.copy_(o_param.data) 191 param.data.copy_(o_param.data)
192 del original_params