Diffstat (limited to 'training/strategy')
-rw-r--r--   training/strategy/dreambooth.py    5
-rw-r--r--   training/strategy/ti.py           14
2 files changed, 10 insertions, 9 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d813b49..f57e736 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks(
     def on_prepare():
         unet.requires_grad_(True)
         text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
+        text_encoder.text_model.embeddings.requires_grad_(False)
 
         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(lr: float, epoch: int):
         if accelerator.sync_gradients:
             params_to_clip = [unet.parameters()]
             if epoch < train_text_encoder_epochs:
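
Note: the second dreambooth.py hunk widens on_before_optimize from (epoch) to (lr, epoch), matching the signature the textual-inversion strategy adopts below, so one training loop can drive either strategy through the same callbacks. A minimal sketch of such a loop follows; callbacks, accelerator, optimizer, lr_scheduler, and compute_loss are assumed names, not code from this repository — only the callback signatures come from this diff:

    # Hypothetical driver loop; only on_before_optimize(lr, epoch) and
    # on_after_optimize(lr) are taken from this diff.
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            loss = compute_loss(batch)          # assumed helper
            accelerator.backward(loss)

            lr = lr_scheduler.get_last_lr()[0]
            callbacks.on_before_optimize(lr, epoch)  # grad clipping / emb decay

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            callbacks.on_after_optimize(lr)          # EMA updates
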
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ba78b98..e922954 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -117,14 +117,15 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_after_optimize(lr: float):
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            with torch.no_grad():
-                text_encoder.text_model.embeddings.normalize(
-                    emb_decay_target,
-                    min(1.0, emb_decay * lr)
-                )
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, emb_decay * lr)
+            )
 
+    def on_after_optimize(lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -154,6 +155,7 @@ def textual_inversion_strategy_callbacks(
         on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
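
Note: in ti.py the embedding decay moves from on_after_optimize into a new on_before_optimize guarded by @torch.no_grad(), so the embedding weights are pulled toward emb_decay_target before optimizer.step(), while the EMA update stays after it. A minimal sketch of what a norm decay like embeddings.normalize(target, lambda_) might do; this is an assumed stand-in, not the repository's implementation:

    import torch

    @torch.no_grad()
    def normalize(weight: torch.Tensor, target: float, lambda_: float) -> None:
        """Decay each embedding row's L2 norm toward `target`.

        Hypothetical stand-in for text_encoder.text_model.embeddings.normalize;
        `lambda_` corresponds to min(1.0, emb_decay * lr) in the diff.
        """
        norm = weight.norm(dim=-1, keepdim=True)
        # Move each row's norm a fraction lambda_ of the way toward the
        # target, then rescale the row accordingly.
        weight.mul_((norm + lambda_ * (target - norm)) / norm)

Under this reading, passing min(1.0, emb_decay * lr) means the pull toward the target norm scales with the learning rate and is capped at 1.0 so a row can never overshoot the target in a single step.
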