diff options
author | Volpeon <git@volpeon.ink> | 2023-01-12 08:51:17 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-12 08:51:17 +0100 |
commit | e454b7e7df13cf6ce7b96b7dcc107533edf83f6f (patch) | |
tree | 235b442919758ab0cf3fecceaf587f855b1cbda3 /train_ti.py | |
parent | Disable Adam weight decay (diff) | |
download | textual-inversion-diff-e454b7e7df13cf6ce7b96b7dcc107533edf83f6f.tar.gz textual-inversion-diff-e454b7e7df13cf6ce7b96b7dcc107533edf83f6f.tar.bz2 textual-inversion-diff-e454b7e7df13cf6ce7b96b7dcc107533edf83f6f.zip |
Fixed TI decay
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 11 |
1 file changed, 3 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py index 2c5037f..890c465 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -7,7 +7,6 @@ from pathlib import Path | |||
7 | from contextlib import contextmanager, nullcontext | 7 | from contextlib import contextmanager, nullcontext |
8 | 8 | ||
9 | import torch | 9 | import torch |
10 | import torch.nn.functional as F | ||
11 | import torch.utils.checkpoint | 10 | import torch.utils.checkpoint |
12 | 11 | ||
13 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
@@ -166,7 +165,7 @@ def parse_args(): | |||
166 | parser.add_argument( | 165 | parser.add_argument( |
167 | "--tag_dropout", | 166 | "--tag_dropout", |
168 | type=float, | 167 | type=float, |
169 | default=0.1, | 168 | default=0, |
170 | help="Tag dropout probability.", | 169 | help="Tag dropout probability.", |
171 | ) | 170 | ) |
172 | parser.add_argument( | 171 | parser.add_argument( |
@@ -177,7 +176,7 @@ def parse_args(): | |||
177 | parser.add_argument( | 176 | parser.add_argument( |
178 | "--vector_dropout", | 177 | "--vector_dropout", |
179 | type=int, | 178 | type=int, |
180 | default=0.1, | 179 | default=0, |
181 | help="Vector dropout probability.", | 180 | help="Vector dropout probability.", |
182 | ) | 181 | ) |
183 | parser.add_argument( | 182 | parser.add_argument( |
@@ -869,11 +868,7 @@ def main(): | |||
869 | 868 | ||
870 | @torch.no_grad() | 869 | @torch.no_grad() |
871 | def on_clip(lr): | 870 | def on_clip(lr): |
872 | embeddings = text_encoder.text_model.embeddings.temp_token_embedding | 871 | text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) |
873 | |||
874 | pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) | ||
875 | lambda_ = min(1.0, 100 * lr) | ||
876 | embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) | ||
877 | 872 | ||
878 | loop = partial( | 873 | loop = partial( |
879 | loss_step, | 874 | loss_step, |