diff options
author | Volpeon <git@volpeon.ink> | 2023-06-25 08:40:05 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-25 08:40:05 +0200 |
commit | 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 (patch) | |
tree | c6ab59d6726b818638fe90a3ea8bb8403a8b0c30 /training | |
parent | Update (diff) | |
download | textual-inversion-diff-4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132.tar.gz textual-inversion-diff-4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132.tar.bz2 textual-inversion-diff-4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132.zip |
Update
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 3 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 3 | ||||
-rw-r--r-- | training/strategy/ti.py | 7 |
3 files changed, 10 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index 8917eb7..b60afe3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -786,7 +786,4 @@ def train( | |||
786 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 786 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
787 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 787 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
788 | 788 | ||
789 | text_encoder.forward = MethodType(text_encoder.forward, text_encoder) | ||
790 | unet.forward = MethodType(unet.forward, unet) | ||
791 | |||
792 | accelerator.free_memory() | 789 | accelerator.free_memory() |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 3d1abf7..7e67589 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks( | |||
154 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 154 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
155 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 155 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
156 | 156 | ||
157 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
158 | unet_.forward = MethodType(unet_.forward, unet_) | ||
159 | |||
157 | text_encoder_.text_model.embeddings.persist(False) | 160 | text_encoder_.text_model.embeddings.persist(False) |
158 | 161 | ||
159 | with ema_context(): | 162 | with ema_context(): |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -1,4 +1,5 @@ | |||
1 | from typing import Optional | 1 | from typing import Optional |
2 | from types import MethodType | ||
2 | from functools import partial | 3 | from functools import partial |
3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
4 | from pathlib import Path | 5 | from pathlib import Path |
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( | |||
139 | def on_checkpoint(step, postfix): | 140 | def on_checkpoint(step, postfix): |
140 | print(f"Saving checkpoint for step {step}...") | 141 | print(f"Saving checkpoint for step {step}...") |
141 | 142 | ||
143 | if postfix == "end": | ||
144 | text_encoder_ = accelerator.unwrap_model( | ||
145 | text_encoder, keep_fp32_wrapper=False | ||
146 | ) | ||
147 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
148 | |||
142 | with ema_context(): | 149 | with ema_context(): |
143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): | 150 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
144 | text_encoder.text_model.embeddings.save_embed( | 151 | text_encoder.text_model.embeddings.save_embed( |