diff options
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/ti.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 83dc566..6f8384f 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -15,7 +15,7 @@ from slugify import slugify | |||
| 15 | 15 | ||
| 16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
| 18 | from training.functional import save_samples | 18 | from training.functional import TrainingCallbacks, save_samples |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy( |
| @@ -153,12 +153,12 @@ def textual_inversion_strategy( | |||
| 153 | with ema_context: | 153 | with ema_context: |
| 154 | save_samples_(step=step) | 154 | save_samples_(step=step) |
| 155 | 155 | ||
| 156 | return { | 156 | return TrainingCallbacks( |
| 157 | "on_prepare": on_prepare, | 157 | on_prepare=on_prepare, |
| 158 | "on_train": on_train, | 158 | on_train=on_train, |
| 159 | "on_eval": on_eval, | 159 | on_eval=on_eval, |
| 160 | "on_after_optimize": on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
| 161 | "on_log": on_log, | 161 | on_log=on_log, |
| 162 | "on_checkpoint": on_checkpoint, | 162 | on_checkpoint=on_checkpoint, |
| 163 | "on_sample": on_sample, | 163 | on_sample=on_sample, |
| 164 | } | 164 | ) |
