diff options
author | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
commit | 632ce00b54ffeacfc18f44f10827f167ab3ac37c (patch) | |
tree | ecf58df2b176d3c7d1583136bf453ed24de8d7f3 /training/util.py | |
parent | Fixed Conda env (diff) | |
download | textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.gz textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.bz2 textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.zip |
Restored functional trainer
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/training/util.py b/training/util.py index f46cc61..557b196 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -180,11 +180,13 @@ class EMAModel: | |||
180 | 180 | ||
181 | @contextmanager | 181 | @contextmanager |
182 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): | 182 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): |
183 | parameters = list(parameters) | ||
184 | original_params = [p.clone() for p in parameters] | ||
185 | self.copy_to(parameters) | ||
186 | |||
183 | try: | 187 | try: |
184 | parameters = list(parameters) | ||
185 | original_params = [p.clone() for p in parameters] | ||
186 | self.copy_to(parameters) | ||
187 | yield | 188 | yield |
188 | finally: | 189 | finally: |
189 | for o_param, param in zip(original_params, parameters): | 190 | for o_param, param in zip(original_params, parameters): |
190 | param.data.copy_(o_param.data) | 191 | param.data.copy_(o_param.data) |
192 | del original_params | ||