summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 83dc566..6f8384f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,7 +15,7 @@ from slugify import slugify
15 15
16from models.clip.tokenizer import MultiCLIPTokenizer 16from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel 17from training.util import EMAModel
18from training.functional import save_samples 18from training.functional import TrainingCallbacks, save_samples
19 19
20 20
21def textual_inversion_strategy( 21def textual_inversion_strategy(
@@ -153,12 +153,12 @@ def textual_inversion_strategy(
153 with ema_context: 153 with ema_context:
154 save_samples_(step=step) 154 save_samples_(step=step)
155 155
156 return { 156 return TrainingCallbacks(
157 "on_prepare": on_prepare, 157 on_prepare=on_prepare,
158 "on_train": on_train, 158 on_train=on_train,
159 "on_eval": on_eval, 159 on_eval=on_eval,
160 "on_after_optimize": on_after_optimize, 160 on_after_optimize=on_after_optimize,
161 "on_log": on_log, 161 on_log=on_log,
162 "on_checkpoint": on_checkpoint, 162 on_checkpoint=on_checkpoint,
163 "on_sample": on_sample, 163 on_sample=on_sample,
164 } 164 )