diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index b7ea5f3..902f508 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -320,6 +320,12 @@ def parse_args(): | |||
| 320 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 320 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
| 321 | ) | 321 | ) |
| 322 | parser.add_argument( | 322 | parser.add_argument( |
| 323 | "--dadaptation_d0", | ||
| 324 | type=float, | ||
| 325 | default=1e-6, | ||
| 326 | help="The d0 parameter for Dadaptation optimizers." | ||
| 327 | ) | ||
| 328 | parser.add_argument( | ||
| 323 | "--adam_beta1", | 329 | "--adam_beta1", |
| 324 | type=float, | 330 | type=float, |
| 325 | default=0.9, | 331 | default=0.9, |
| @@ -659,6 +665,7 @@ def main(): | |||
| 659 | weight_decay=args.adam_weight_decay, | 665 | weight_decay=args.adam_weight_decay, |
| 660 | eps=args.adam_epsilon, | 666 | eps=args.adam_epsilon, |
| 661 | decouple=True, | 667 | decouple=True, |
| 668 | d0=args.dadaptation_d0, | ||
| 662 | ) | 669 | ) |
| 663 | elif args.optimizer == 'dadan': | 670 | elif args.optimizer == 'dadan': |
| 664 | try: | 671 | try: |
| @@ -670,6 +677,7 @@ def main(): | |||
| 670 | dadaptation.DAdaptAdan, | 677 | dadaptation.DAdaptAdan, |
| 671 | weight_decay=args.adam_weight_decay, | 678 | weight_decay=args.adam_weight_decay, |
| 672 | eps=args.adam_epsilon, | 679 | eps=args.adam_epsilon, |
| 680 | d0=args.dadaptation_d0, | ||
| 673 | ) | 681 | ) |
| 674 | else: | 682 | else: |
| 675 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 683 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| @@ -690,7 +698,6 @@ def main(): | |||
| 690 | no_val=args.valid_set_size == 0, | 698 | no_val=args.valid_set_size == 0, |
| 691 | strategy=textual_inversion_strategy, | 699 | strategy=textual_inversion_strategy, |
| 692 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 700 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 693 | sample_frequency=args.sample_frequency, | ||
| 694 | checkpoint_frequency=args.checkpoint_frequency, | 701 | checkpoint_frequency=args.checkpoint_frequency, |
| 695 | milestone_checkpoints=not args.no_milestone_checkpoints, | 702 | milestone_checkpoints=not args.no_milestone_checkpoints, |
| 696 | global_step_offset=global_step_offset, | 703 | global_step_offset=global_step_offset, |
| @@ -759,10 +766,9 @@ def main(): | |||
| 759 | datamodule.setup() | 766 | datamodule.setup() |
| 760 | 767 | ||
| 761 | num_train_epochs = args.num_train_epochs | 768 | num_train_epochs = args.num_train_epochs |
| 762 | |||
| 763 | if num_train_epochs is None: | 769 | if num_train_epochs is None: |
| 764 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size | 770 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
| 765 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | 771 | sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) |
| 766 | 772 | ||
| 767 | optimizer = create_optimizer( | 773 | optimizer = create_optimizer( |
| 768 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 774 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
| @@ -792,6 +798,7 @@ def main(): | |||
| 792 | optimizer=optimizer, | 798 | optimizer=optimizer, |
| 793 | lr_scheduler=lr_scheduler, | 799 | lr_scheduler=lr_scheduler, |
| 794 | num_train_epochs=num_train_epochs, | 800 | num_train_epochs=num_train_epochs, |
| 801 | sample_frequency=sample_frequency, | ||
| 795 | # -- | 802 | # -- |
| 796 | sample_output_dir=sample_output_dir, | 803 | sample_output_dir=sample_output_dir, |
| 797 | placeholder_tokens=placeholder_tokens, | 804 | placeholder_tokens=placeholder_tokens, |
