summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>  2023-01-11 22:29:28 +0100
committer Volpeon <git@volpeon.ink>  2023-01-11 22:29:28 +0100
commit    b89953bea7dfe6c92164888a66d05bc7d987ef71 (patch)
tree      7cfb060de5cb981373572bc0c8dfd7152b9e9173 /train_ti.py
parent    Heck (diff)
download  textual-inversion-diff-b89953bea7dfe6c92164888a66d05bc7d987ef71.tar.gz
          textual-inversion-diff-b89953bea7dfe6c92164888a66d05bc7d987ef71.tar.bz2
          textual-inversion-diff-b89953bea7dfe6c92164888a66d05bc7d987ef71.zip
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  6
1 file changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index e6c437e..951f3dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -868,11 +868,11 @@ def main():
868 pass 868 pass
869 869
870 @torch.no_grad() 870 @torch.no_grad()
871 def on_clip(): 871 def on_clip(lr):
872 embeddings = text_encoder.text_model.embeddings.temp_token_embedding 872 embeddings = text_encoder.text_model.embeddings.temp_token_embedding
873 873
874 pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) 874 pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
875 lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0]) 875 lambda_ = min(1.0, 100 * lr)
876 embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) 876 embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
877 877
878 loop = partial( 878 loop = partial(
@@ -991,7 +991,7 @@ def main():
991 accelerator.backward(loss) 991 accelerator.backward(loss)
992 992
993 if accelerator.sync_gradients: 993 if accelerator.sync_gradients:
994 on_clip() 994 on_clip(lr_scheduler.get_last_lr()[0])
995 995
996 optimizer.step() 996 optimizer.step()
997 if not accelerator.optimizer_step_was_skipped: 997 if not accelerator.optimizer_step_was_skipped: