diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-04 07:30:43 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-04 07:30:43 +0200 |
| commit | 30b557c8e1f03b4748ac3efca599ff51d66561cb (patch) | |
| tree | 59aaacde83a7a44dc267c64455f6dc2cfb90c01f /train_ti.py | |
| parent | Improved sparse embeddings (diff) | |
| download | textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.gz textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.bz2 textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.zip | |
TI: Bring back old embedding decay
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 24 |
1 file changed, 19 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index a9a2333..4366c9e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -353,7 +353,7 @@ def parse_args(): | |||
| 353 | parser.add_argument( | 353 | parser.add_argument( |
| 354 | "--adam_weight_decay", | 354 | "--adam_weight_decay", |
| 355 | type=float, | 355 | type=float, |
| 356 | default=1e-2, | 356 | default=0, |
| 357 | help="Weight decay to use." | 357 | help="Weight decay to use." |
| 358 | ) | 358 | ) |
| 359 | parser.add_argument( | 359 | parser.add_argument( |
| @@ -451,10 +451,21 @@ def parse_args(): | |||
| 451 | help="The weight of prior preservation loss." | 451 | help="The weight of prior preservation loss." |
| 452 | ) | 452 | ) |
| 453 | parser.add_argument( | 453 | parser.add_argument( |
| 454 | "--emb_alpha", | 454 | "--use_emb_decay", |
| 455 | default=1.0, | 455 | action="store_true", |
| 456 | help="Whether to use embedding decay." | ||
| 457 | ) | ||
| 458 | parser.add_argument( | ||
| 459 | "--emb_decay_target", | ||
| 460 | default=0.4, | ||
| 461 | type=float, | ||
| 462 | help="Embedding decay target." | ||
| 463 | ) | ||
| 464 | parser.add_argument( | ||
| 465 | "--emb_decay", | ||
| 466 | default=1e+2, | ||
| 456 | type=float, | 467 | type=float, |
| 457 | help="Embedding alpha." | 468 | help="Embedding decay factor." |
| 458 | ) | 469 | ) |
| 459 | parser.add_argument( | 470 | parser.add_argument( |
| 460 | "--noise_timesteps", | 471 | "--noise_timesteps", |
| @@ -600,7 +611,7 @@ def main(): | |||
| 600 | save_args(output_dir, args) | 611 | save_args(output_dir, args) |
| 601 | 612 | ||
| 602 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 613 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 603 | args.pretrained_model_name_or_path, args.emb_alpha) | 614 | args.pretrained_model_name_or_path) |
| 604 | 615 | ||
| 605 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 616 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 606 | tokenizer.set_dropout(args.vector_dropout) | 617 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -744,6 +755,9 @@ def main(): | |||
| 744 | tokenizer=tokenizer, | 755 | tokenizer=tokenizer, |
| 745 | sample_scheduler=sample_scheduler, | 756 | sample_scheduler=sample_scheduler, |
| 746 | checkpoint_output_dir=checkpoint_output_dir, | 757 | checkpoint_output_dir=checkpoint_output_dir, |
| 758 | use_emb_decay=args.use_emb_decay, | ||
| 759 | emb_decay_target=args.emb_decay_target, | ||
| 760 | emb_decay=args.emb_decay, | ||
| 747 | use_ema=args.use_ema, | 761 | use_ema=args.use_ema, |
| 748 | ema_inv_gamma=args.ema_inv_gamma, | 762 | ema_inv_gamma=args.ema_inv_gamma, |
| 749 | ema_power=args.ema_power, | 763 | ema_power=args.ema_power, |
