diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/dreambooth.py | 3 | ||||
-rw-r--r-- | training/strategy/ti.py | 7 |
2 files changed, 10 insertions, 0 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 3d1abf7..7e67589 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks( | |||
154 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 154 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
155 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 155 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
156 | 156 | ||
157 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
158 | unet_.forward = MethodType(unet_.forward, unet_) | ||
159 | |||
157 | text_encoder_.text_model.embeddings.persist(False) | 160 | text_encoder_.text_model.embeddings.persist(False) |
158 | 161 | ||
159 | with ema_context(): | 162 | with ema_context(): |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -1,4 +1,5 @@ | |||
1 | from typing import Optional | 1 | from typing import Optional |
2 | from types import MethodType | ||
2 | from functools import partial | 3 | from functools import partial |
3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
4 | from pathlib import Path | 5 | from pathlib import Path |
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( | |||
139 | def on_checkpoint(step, postfix): | 140 | def on_checkpoint(step, postfix): |
140 | print(f"Saving checkpoint for step {step}...") | 141 | print(f"Saving checkpoint for step {step}...") |
141 | 142 | ||
143 | if postfix == "end": | ||
144 | text_encoder_ = accelerator.unwrap_model( | ||
145 | text_encoder, keep_fp32_wrapper=False | ||
146 | ) | ||
147 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
148 | |||
142 | with ema_context(): | 149 | with ema_context(): |
143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): | 150 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
144 | text_encoder.text_model.embeddings.save_embed( | 151 | text_encoder.text_model.embeddings.save_embed( |