summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-15 10:38:49 +0100
committerVolpeon <git@volpeon.ink>2023-01-15 10:38:49 +0100
commit37baa3aa254af721728aa33befdc383858cb8ea2 (patch)
treeebf64291e052280eea661f8a8d96c486dd5c1cf6 /training/strategy/ti.py
parentAdded functional TI strategy (diff)
downloadtextual-inversion-diff-37baa3aa254af721728aa33befdc383858cb8ea2.tar.gz
textual-inversion-diff-37baa3aa254af721728aa33befdc383858cb8ea2.tar.bz2
textual-inversion-diff-37baa3aa254af721728aa33befdc383858cb8ea2.zip
Removed unused code, put training callbacks in dataclass
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 83dc566..6f8384f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,7 +15,7 @@ from slugify import slugify
15 15
16from models.clip.tokenizer import MultiCLIPTokenizer 16from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel 17from training.util import EMAModel
18from training.functional import save_samples 18from training.functional import TrainingCallbacks, save_samples
19 19
20 20
21def textual_inversion_strategy( 21def textual_inversion_strategy(
@@ -153,12 +153,12 @@ def textual_inversion_strategy(
153 with ema_context: 153 with ema_context:
154 save_samples_(step=step) 154 save_samples_(step=step)
155 155
156 return { 156 return TrainingCallbacks(
157 "on_prepare": on_prepare, 157 on_prepare=on_prepare,
158 "on_train": on_train, 158 on_train=on_train,
159 "on_eval": on_eval, 159 on_eval=on_eval,
160 "on_after_optimize": on_after_optimize, 160 on_after_optimize=on_after_optimize,
161 "on_log": on_log, 161 on_log=on_log,
162 "on_checkpoint": on_checkpoint, 162 on_checkpoint=on_checkpoint,
163 "on_sample": on_sample, 163 on_sample=on_sample,
164 } 164 )