path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-01-12 08:51:17 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-12 08:51:17 +0100
commit     e454b7e7df13cf6ce7b96b7dcc107533edf83f6f (patch)
tree       235b442919758ab0cf3fecceaf587f855b1cbda3 /train_ti.py
parent     Disable Adam weight decay (diff)
Fixed TI decay
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 2c5037f..890c465 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ from pathlib import Path
 from contextlib import contextmanager, nullcontext
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from accelerate import Accelerator
@@ -166,7 +165,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0.1,
+        default=0,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -177,7 +176,7 @@ def parse_args():
     parser.add_argument(
         "--vector_dropout",
         type=int,
-        default=0.1,
+        default=0,
         help="Vector dropout probability.",
     )
     parser.add_argument(
@@ -869,11 +868,7 @@ def main():
 
     @torch.no_grad()
     def on_clip(lr):
-        embeddings = text_encoder.text_model.embeddings.temp_token_embedding
-
-        pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
-        lambda_ = min(1.0, 100 * lr)
-        embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
+        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
 
     loop = partial(
         loss_step,
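
The added one-liner delegates the embedding decay to a normalize method on the embeddings object. Below is a minimal sketch of what that helper presumably implements, reconstructed from the inline code this commit removes: the method name and its single argument come from the call site, the 0.4 target norm comes from the deleted lines, and the class name and constructor are hypothetical stand-ins.

import torch
import torch.nn.functional as F


class ManagedEmbeddings(torch.nn.Module):
    # Hypothetical stand-in for the patched text_model.embeddings class.
    def __init__(self, temp_token_embedding: torch.nn.Embedding):
        super().__init__()
        self.temp_token_embedding = temp_token_embedding

    @torch.no_grad()
    def normalize(self, lambda_: float = 1.0):
        # Pull each trainable token embedding's norm toward 0.4, moving a
        # fraction lambda_ of the way there; the call site passes
        # lambda_ = min(1.0, 100 * lr), so the step never overshoots.
        w = self.temp_token_embedding.weight
        pre_norm = w.norm(dim=-1, keepdim=True)
        w.copy_(F.normalize(w, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)))

With a method like this, the call site reduces to the added line, text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)), and train_ti.py no longer needs the torch.nn.functional import dropped in the first hunk.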