diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-25 08:40:05 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-25 08:40:05 +0200 |
| commit | 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 (patch) | |
| tree | c6ab59d6726b818638fe90a3ea8bb8403a8b0c30 /training/strategy/ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132.tar.gz textual-inversion-diff-4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132.tar.bz2 textual-inversion-diff-4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132.zip | |
Update
Diffstat (limited to 'training/strategy/ti.py')
| -rw-r--r-- | training/strategy/ti.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -1,4 +1,5 @@ | |||
| 1 | from typing import Optional | 1 | from typing import Optional |
| 2 | from types import MethodType | ||
| 2 | from functools import partial | 3 | from functools import partial |
| 3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
| 4 | from pathlib import Path | 5 | from pathlib import Path |
| @@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( | |||
| 139 | def on_checkpoint(step, postfix): | 140 | def on_checkpoint(step, postfix): |
| 140 | print(f"Saving checkpoint for step {step}...") | 141 | print(f"Saving checkpoint for step {step}...") |
| 141 | 142 | ||
| 143 | if postfix == "end": | ||
| 144 | text_encoder_ = accelerator.unwrap_model( | ||
| 145 | text_encoder, keep_fp32_wrapper=False | ||
| 146 | ) | ||
| 147 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
| 148 | |||
| 142 | with ema_context(): | 149 | with ema_context(): |
| 143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): | 150 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
| 144 | text_encoder.text_model.embeddings.save_embed( | 151 | text_encoder.text_model.embeddings.save_embed( |
