| author | Volpeon <git@volpeon.ink> | 2023-04-03 18:52:30 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-03 18:52:30 +0200 |
| commit | e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 (patch) | |
| tree | 87fbb9d92233aa1bb7342e31aec64d6d375f41e1 /train_ti.py | |
| parent | TI: No tag dropout by default (diff) | |
| download | textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.gz, textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.bz2, textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.zip | |
TI: Delta learning
Diffstat (limited to 'train_ti.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 37 |

1 file changed, 11 insertions(+), 26 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8dde1ba..0ad7574 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
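The weight-decay default moves from 0 to 1e-2, which matches torch.optim.AdamW's own default. As a rough sketch of where that value typically ends up, assuming a diffusers-style optimizer setup (the optimizer construction is outside this hunk, and names like `args.learning_rate` and `args.adam_epsilon` are assumptions rather than code from this commit):

```python
# Sketch only: how --adam_weight_decay is commonly wired into the optimizer.
# This is NOT the actual optimizer setup from train_ti.py.
import torch

def build_optimizer(trainable_params, args):
    # 1e-2 is also torch.optim.AdamW's default decoupled weight decay.
    return torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,                # assumed flag, not shown in this diff
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,  # now defaults to 1e-2 instead of 0
        eps=args.adam_epsilon,                # assumed flag, not shown in this diff
    )
```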
@@ -451,21 +451,10 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
+        "--emb_alpha",
+        default=1.0,
         type=float,
-        help="Embedding decay factor."
+        help="Embedding alpha."
     )
     parser.add_argument(
         "--noise_timesteps",
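The three embedding-decay flags (`--use_emb_decay`, `--emb_decay_target`, `--emb_decay`) collapse into a single `--emb_alpha`, which gets passed to `get_models` further down. The commit title calls this "delta learning"; the diff does not show how alpha is applied, but one plausible reading, purely as an illustration and not the repository's actual embeddings code, is a frozen base embedding plus an alpha-scaled trainable offset:

```python
# Hypothetical illustration of alpha-scaled "delta learning" for token embeddings.
# The real logic lives in get_models / the embeddings module, which this diff only calls.
import torch
import torch.nn as nn

class DeltaTokenEmbedding(nn.Module):
    def __init__(self, base: torch.Tensor, alpha: float = 1.0):
        super().__init__()
        self.register_buffer("base", base.clone())         # frozen initialization
        self.delta = nn.Parameter(torch.zeros_like(base))   # the only trainable part
        self.alpha = alpha                                   # corresponds to --emb_alpha

    def forward(self) -> torch.Tensor:
        # Effective embedding = frozen base + alpha * learned delta.
        return self.base + self.alpha * self.delta
```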
@@ -567,16 +556,16 @@ def parse_args():
         raise ValueError("You must specify --output_dir")
 
     if args.adam_beta1 is None:
-        if args.optimizer in ('adam', 'adam8bit'):
-            args.adam_beta1 = 0.9
-        elif args.optimizer == 'lion':
+        if args.optimizer == 'lion':
             args.adam_beta1 = 0.95
+        else:
+            args.adam_beta1 = 0.9
 
     if args.adam_beta2 is None:
-        if args.optimizer in ('adam', 'adam8bit'):
-            args.adam_beta2 = 0.999
-        elif args.optimizer == 'lion':
+        if args.optimizer == 'lion':
             args.adam_beta2 = 0.98
+        else:
+            args.adam_beta2 = 0.999
 
     return args
 
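Besides being shorter, the rewritten defaults change behavior slightly: before, only `adam`/`adam8bit` received default betas, whereas now every optimizer other than Lion falls back to the Adam-style 0.9/0.999. A minimal sketch of the resulting resolution (argparse wiring omitted):

```python
# Minimal sketch of the new beta-default resolution; argparse wiring omitted.
def resolve_betas(optimizer: str, beta1=None, beta2=None):
    if beta1 is None:
        beta1 = 0.95 if optimizer == "lion" else 0.9
    if beta2 is None:
        beta2 = 0.98 if optimizer == "lion" else 0.999
    return beta1, beta2

print(resolve_betas("lion"))      # (0.95, 0.98)
print(resolve_betas("adam8bit"))  # (0.9, 0.999); any non-Lion optimizer now gets these
```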
@@ -611,7 +600,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_alpha)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,10 +744,6 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
