author     Volpeon <git@volpeon.ink>  2023-01-17 16:39:33 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-17 16:39:33 +0100
commit     8e9d62225db11913bf7ef67221fc3508d7fe1149
tree       4c17e8491a77bc92deb276dedba7949a8bb7297a /training/strategy
parent     Optimized embedding normalization
Update
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py |  5
-rw-r--r--  training/strategy/ti.py         | 14
2 files changed, 10 insertions, 9 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d813b49..f57e736 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks(
     def on_prepare():
         unet.requires_grad_(True)
         text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
+        text_encoder.text_model.embeddings.requires_grad_(False)
 
         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(lr: float, epoch: int):
         if accelerator.sync_gradients:
             params_to_clip = [unet.parameters()]
             if epoch < train_text_encoder_epochs:
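Note on the hunks above: the first replaces the persist-then-freeze-temp-embedding sequence with a single requires_grad_(False) on the whole embedding module, so no gradients flow to any embedding parameters during DreamBooth training; the second threads the current learning rate into on_before_optimize() to match the textual-inversion strategy below. A minimal, self-contained sketch of the freezing pattern, using a toy module as an illustrative stand-in for the CLIP text encoder (names here are not from this repo):

import torch.nn as nn

# Toy stand-in for the text encoder: an embedding table plus a projection.
model = nn.Sequential(nn.Embedding(16, 8), nn.Linear(8, 8))

# Mirror on_prepare(): mark everything trainable, then freeze the embeddings.
model.requires_grad_(True)
model[0].requires_grad_(False)

print([name for name, p in model.named_parameters() if p.requires_grad])
# -> ['1.weight', '1.bias']: only the Linear layer remains trainable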
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ba78b98..e922954 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -117,14 +117,15 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_after_optimize(lr: float):
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            with torch.no_grad():
-                text_encoder.text_model.embeddings.normalize(
-                    emb_decay_target,
-                    min(1.0, emb_decay * lr)
-                )
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, emb_decay * lr)
+            )
 
+    def on_after_optimize(lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -154,6 +155,7 @@ def textual_inversion_strategy_callbacks(
         on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
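The ti.py hunks move the embedding decay out of on_after_optimize() into a new on_before_optimize(lr, epoch) decorated with @torch.no_grad() (which makes the inner with-block redundant), then register the new callback. normalize() is defined elsewhere in this repo; assuming it pulls each embedding vector's norm toward emb_decay_target by a factor of min(1.0, emb_decay * lr), a generic sketch of that idea (hypothetical helper, not the repo's implementation):

import torch

@torch.no_grad()  # same decorator style as the new on_before_optimize
def decay_embedding_norms(weight: torch.Tensor, target: float, factor: float) -> None:
    # Move each row's L2 norm `factor` of the way toward `target`.
    norms = weight.norm(dim=-1, keepdim=True)
    weight.mul_((norms + factor * (target - norms)) / norms)

emb = torch.nn.Embedding(8, 4)
lr, emb_decay, emb_decay_target = 5e-4, 100.0, 0.4  # assumed example values
decay_embedding_norms(emb.weight, emb_decay_target, min(1.0, emb_decay * lr))
print(emb.weight.norm(dim=-1))  # row norms nudged toward 0.4

Running the decay in on_before_optimize() means it adjusts the weights the current gradients were computed against, before optimizer.step() applies them, rather than rescaling the freshly updated weights afterwards.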