summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 00f3529..597abd0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,10 +15,10 @@ from slugify import slugify
15 15
16from models.clip.tokenizer import MultiCLIPTokenizer 16from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel 17from training.util import EMAModel
18from training.functional import TrainingCallbacks, save_samples 18from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
19 19
20 20
21def textual_inversion_strategy( 21def textual_inversion_strategy_callbacks(
22 accelerator: Accelerator, 22 accelerator: Accelerator,
23 unet: UNet2DConditionModel, 23 unet: UNet2DConditionModel,
24 text_encoder: CLIPTextModel, 24 text_encoder: CLIPTextModel,
@@ -119,17 +119,18 @@ def textual_inversion_strategy(
119 with ema_context(): 119 with ema_context():
120 yield 120 yield
121 121
122 @torch.no_grad()
123 def on_after_optimize(lr: float): 122 def on_after_optimize(lr: float):
123 if use_ema:
124 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
125
126 @torch.no_grad()
127 def on_after_epoch(lr: float):
124 if use_emb_decay: 128 if use_emb_decay:
125 text_encoder.text_model.embeddings.normalize( 129 text_encoder.text_model.embeddings.normalize(
126 emb_decay_target, 130 emb_decay_target,
127 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) 131 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
128 ) 132 )
129 133
130 if use_ema:
131 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
132
133 def on_log(): 134 def on_log():
134 if use_ema: 135 if use_ema:
135 return {"ema_decay": ema_embeddings.decay} 136 return {"ema_decay": ema_embeddings.decay}
@@ -157,7 +158,13 @@ def textual_inversion_strategy(
157 on_train=on_train, 158 on_train=on_train,
158 on_eval=on_eval, 159 on_eval=on_eval,
159 on_after_optimize=on_after_optimize, 160 on_after_optimize=on_after_optimize,
161 on_after_epoch=on_after_epoch,
160 on_log=on_log, 162 on_log=on_log,
161 on_checkpoint=on_checkpoint, 163 on_checkpoint=on_checkpoint,
162 on_sample=on_sample, 164 on_sample=on_sample,
163 ) 165 )
166
167
168textual_inversion_strategy = TrainingStrategy(
169 callbacks=textual_inversion_strategy_callbacks,
170)