diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/ti.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 83dc566..6f8384f 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -15,7 +15,7 @@ from slugify import slugify | |||
15 | 15 | ||
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
18 | from training.functional import save_samples | 18 | from training.functional import TrainingCallbacks, save_samples |
19 | 19 | ||
20 | 20 | ||
21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy( |
@@ -153,12 +153,12 @@ def textual_inversion_strategy( | |||
153 | with ema_context: | 153 | with ema_context: |
154 | save_samples_(step=step) | 154 | save_samples_(step=step) |
155 | 155 | ||
156 | return { | 156 | return TrainingCallbacks( |
157 | "on_prepare": on_prepare, | 157 | on_prepare=on_prepare, |
158 | "on_train": on_train, | 158 | on_train=on_train, |
159 | "on_eval": on_eval, | 159 | on_eval=on_eval, |
160 | "on_after_optimize": on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
161 | "on_log": on_log, | 161 | on_log=on_log, |
162 | "on_checkpoint": on_checkpoint, | 162 | on_checkpoint=on_checkpoint, |
163 | "on_sample": on_sample, | 163 | on_sample=on_sample, |
164 | } | 164 | ) |