diff options
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/dreambooth.py | 3 | ||||
| -rw-r--r-- | training/strategy/ti.py | 7 |
2 files changed, 10 insertions, 0 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 3d1abf7..7e67589 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks( | |||
| 154 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 154 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 155 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 155 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 156 | 156 | ||
| 157 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
| 158 | unet_.forward = MethodType(unet_.forward, unet_) | ||
| 159 | |||
| 157 | text_encoder_.text_model.embeddings.persist(False) | 160 | text_encoder_.text_model.embeddings.persist(False) |
| 158 | 161 | ||
| 159 | with ema_context(): | 162 | with ema_context(): |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -1,4 +1,5 @@ | |||
| 1 | from typing import Optional | 1 | from typing import Optional |
| 2 | from types import MethodType | ||
| 2 | from functools import partial | 3 | from functools import partial |
| 3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
| 4 | from pathlib import Path | 5 | from pathlib import Path |
| @@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( | |||
| 139 | def on_checkpoint(step, postfix): | 140 | def on_checkpoint(step, postfix): |
| 140 | print(f"Saving checkpoint for step {step}...") | 141 | print(f"Saving checkpoint for step {step}...") |
| 141 | 142 | ||
| 143 | if postfix == "end": | ||
| 144 | text_encoder_ = accelerator.unwrap_model( | ||
| 145 | text_encoder, keep_fp32_wrapper=False | ||
| 146 | ) | ||
| 147 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
| 148 | |||
| 142 | with ema_context(): | 149 | with ema_context(): |
| 143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): | 150 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
| 144 | text_encoder.text_model.embeddings.save_embed( | 151 | text_encoder.text_model.embeddings.save_embed( |
