diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-13 09:17:44 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-13 09:17:44 +0100 |
| commit | 07e8111a98afbf912ba99a28e2d7ea647b2d29fc (patch) | |
| tree | 853aaae6f5ff477b73d5649f3f89303ff8654b70 /train_ti.py | |
| parent | Code deduplication (diff) | |
| download | textual-inversion-diff-07e8111a98afbf912ba99a28e2d7ea647b2d29fc.tar.gz textual-inversion-diff-07e8111a98afbf912ba99a28e2d7ea647b2d29fc.tar.bz2 textual-inversion-diff-07e8111a98afbf912ba99a28e2d7ea647b2d29fc.zip | |
Added TI decay start offset
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index 3b7e3b1..e18ee38 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -415,11 +415,17 @@ def parse_args(): | |||
| 415 | ) | 415 | ) |
| 416 | parser.add_argument( | 416 | parser.add_argument( |
| 417 | "--decay_factor", | 417 | "--decay_factor", |
| 418 | default=100, | 418 | default=1, |
| 419 | type=float, | 419 | type=float, |
| 420 | help="Embedding decay factor." | 420 | help="Embedding decay factor." |
| 421 | ) | 421 | ) |
| 422 | parser.add_argument( | 422 | parser.add_argument( |
| 423 | "--decay_start", | ||
| 424 | default=1e-4, | ||
| 425 | type=float, | ||
| 426 | help="Embedding decay start offset." | ||
| 427 | ) | ||
| 428 | parser.add_argument( | ||
| 423 | "--noise_timesteps", | 429 | "--noise_timesteps", |
| 424 | type=int, | 430 | type=int, |
| 425 | default=1000, | 431 | default=1000, |
| @@ -833,7 +839,7 @@ def main(): | |||
| 833 | def on_after_optimize(lr: float): | 839 | def on_after_optimize(lr: float): |
| 834 | text_encoder.text_model.embeddings.normalize( | 840 | text_encoder.text_model.embeddings.normalize( |
| 835 | args.decay_target, | 841 | args.decay_target, |
| 836 | min(1.0, args.decay_factor * lr) | 842 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) |
| 837 | ) | 843 | ) |
| 838 | 844 | ||
| 839 | loop = partial( | 845 | loop = partial( |
