diff options
author | Volpeon <git@volpeon.ink> | 2023-03-31 12:48:18 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-31 12:48:18 +0200 |
commit | e8d17379efa491a019c3a5fc2633e55ab87e3432 (patch) | |
tree | 0c65fefd5a8fe91d516fda67655b16eae2fa6d91 | |
parent | Fix (diff) | |
download | textual-inversion-diff-e8d17379efa491a019c3a5fc2633e55ab87e3432.tar.gz textual-inversion-diff-e8d17379efa491a019c3a5fc2633e55ab87e3432.tar.bz2 textual-inversion-diff-e8d17379efa491a019c3a5fc2633e55ab87e3432.zip |
Support Dadaptation d0, adjust sample freq when steps instead of epochs are used
-rw-r--r-- | train_dreambooth.py | 15 | ||||
-rw-r--r-- | train_lora.py | 15 | ||||
-rw-r--r-- | train_ti.py | 15 |
3 files changed, 33 insertions, 12 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index f1dca7f..d2e60ec 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -302,6 +302,12 @@ def parse_args(): | |||
302 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 302 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
303 | ) | 303 | ) |
304 | parser.add_argument( | 304 | parser.add_argument( |
305 | "--dadaptation_d0", | ||
306 | type=float, | ||
307 | default=1e-6, | ||
308 | help="The d0 parameter for Dadaptation optimizers." | ||
309 | ) | ||
310 | parser.add_argument( | ||
305 | "--adam_beta1", | 311 | "--adam_beta1", |
306 | type=float, | 312 | type=float, |
307 | default=0.9, | 313 | default=0.9, |
@@ -535,6 +541,7 @@ def main(): | |||
535 | weight_decay=args.adam_weight_decay, | 541 | weight_decay=args.adam_weight_decay, |
536 | eps=args.adam_epsilon, | 542 | eps=args.adam_epsilon, |
537 | decouple=True, | 543 | decouple=True, |
544 | d0=args.dadaptation_d0, | ||
538 | ) | 545 | ) |
539 | 546 | ||
540 | args.learning_rate = 1.0 | 547 | args.learning_rate = 1.0 |
@@ -548,6 +555,7 @@ def main(): | |||
548 | dadaptation.DAdaptAdan, | 555 | dadaptation.DAdaptAdan, |
549 | weight_decay=args.adam_weight_decay, | 556 | weight_decay=args.adam_weight_decay, |
550 | eps=args.adam_epsilon, | 557 | eps=args.adam_epsilon, |
558 | d0=args.dadaptation_d0, | ||
551 | ) | 559 | ) |
552 | 560 | ||
553 | args.learning_rate = 1.0 | 561 | args.learning_rate = 1.0 |
@@ -596,10 +604,9 @@ def main(): | |||
596 | datamodule.setup() | 604 | datamodule.setup() |
597 | 605 | ||
598 | num_train_epochs = args.num_train_epochs | 606 | num_train_epochs = args.num_train_epochs |
599 | |||
600 | if num_train_epochs is None: | 607 | if num_train_epochs is None: |
601 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size | 608 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
602 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | 609 | sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) |
603 | 610 | ||
604 | params_to_optimize = (unet.parameters(), ) | 611 | params_to_optimize = (unet.parameters(), ) |
605 | if args.train_text_encoder_epochs != 0: | 612 | if args.train_text_encoder_epochs != 0: |
@@ -639,7 +646,7 @@ def main(): | |||
639 | lr_scheduler=lr_scheduler, | 646 | lr_scheduler=lr_scheduler, |
640 | num_train_epochs=num_train_epochs, | 647 | num_train_epochs=num_train_epochs, |
641 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 648 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
642 | sample_frequency=args.sample_frequency, | 649 | sample_frequency=sample_frequency, |
643 | offset_noise_strength=args.offset_noise_strength, | 650 | offset_noise_strength=args.offset_noise_strength, |
644 | # -- | 651 | # -- |
645 | tokenizer=tokenizer, | 652 | tokenizer=tokenizer, |
diff --git a/train_lora.py b/train_lora.py index 9975462..7b54ef8 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -314,6 +314,12 @@ def parse_args(): | |||
314 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 314 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
315 | ) | 315 | ) |
316 | parser.add_argument( | 316 | parser.add_argument( |
317 | "--dadaptation_d0", | ||
318 | type=float, | ||
319 | default=1e-6, | ||
320 | help="The d0 parameter for Dadaptation optimizers." | ||
321 | ) | ||
322 | parser.add_argument( | ||
317 | "--adam_beta1", | 323 | "--adam_beta1", |
318 | type=float, | 324 | type=float, |
319 | default=0.9, | 325 | default=0.9, |
@@ -567,6 +573,7 @@ def main(): | |||
567 | weight_decay=args.adam_weight_decay, | 573 | weight_decay=args.adam_weight_decay, |
568 | eps=args.adam_epsilon, | 574 | eps=args.adam_epsilon, |
569 | decouple=True, | 575 | decouple=True, |
576 | d0=args.dadaptation_d0, | ||
570 | ) | 577 | ) |
571 | 578 | ||
572 | args.learning_rate = 1.0 | 579 | args.learning_rate = 1.0 |
@@ -580,6 +587,7 @@ def main(): | |||
580 | dadaptation.DAdaptAdan, | 587 | dadaptation.DAdaptAdan, |
581 | weight_decay=args.adam_weight_decay, | 588 | weight_decay=args.adam_weight_decay, |
582 | eps=args.adam_epsilon, | 589 | eps=args.adam_epsilon, |
590 | d0=args.dadaptation_d0, | ||
583 | ) | 591 | ) |
584 | 592 | ||
585 | args.learning_rate = 1.0 | 593 | args.learning_rate = 1.0 |
@@ -628,10 +636,9 @@ def main(): | |||
628 | datamodule.setup() | 636 | datamodule.setup() |
629 | 637 | ||
630 | num_train_epochs = args.num_train_epochs | 638 | num_train_epochs = args.num_train_epochs |
631 | |||
632 | if num_train_epochs is None: | 639 | if num_train_epochs is None: |
633 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size | 640 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
634 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | 641 | sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) |
635 | 642 | ||
636 | optimizer = create_optimizer( | 643 | optimizer = create_optimizer( |
637 | itertools.chain( | 644 | itertools.chain( |
@@ -667,7 +674,7 @@ def main(): | |||
667 | lr_scheduler=lr_scheduler, | 674 | lr_scheduler=lr_scheduler, |
668 | num_train_epochs=num_train_epochs, | 675 | num_train_epochs=num_train_epochs, |
669 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 676 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
670 | sample_frequency=args.sample_frequency, | 677 | sample_frequency=sample_frequency, |
671 | offset_noise_strength=args.offset_noise_strength, | 678 | offset_noise_strength=args.offset_noise_strength, |
672 | # -- | 679 | # -- |
673 | tokenizer=tokenizer, | 680 | tokenizer=tokenizer, |
diff --git a/train_ti.py b/train_ti.py index b7ea5f3..902f508 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -320,6 +320,12 @@ def parse_args(): | |||
320 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 320 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
321 | ) | 321 | ) |
322 | parser.add_argument( | 322 | parser.add_argument( |
323 | "--dadaptation_d0", | ||
324 | type=float, | ||
325 | default=1e-6, | ||
326 | help="The d0 parameter for Dadaptation optimizers." | ||
327 | ) | ||
328 | parser.add_argument( | ||
323 | "--adam_beta1", | 329 | "--adam_beta1", |
324 | type=float, | 330 | type=float, |
325 | default=0.9, | 331 | default=0.9, |
@@ -659,6 +665,7 @@ def main(): | |||
659 | weight_decay=args.adam_weight_decay, | 665 | weight_decay=args.adam_weight_decay, |
660 | eps=args.adam_epsilon, | 666 | eps=args.adam_epsilon, |
661 | decouple=True, | 667 | decouple=True, |
668 | d0=args.dadaptation_d0, | ||
662 | ) | 669 | ) |
663 | elif args.optimizer == 'dadan': | 670 | elif args.optimizer == 'dadan': |
664 | try: | 671 | try: |
@@ -670,6 +677,7 @@ def main(): | |||
670 | dadaptation.DAdaptAdan, | 677 | dadaptation.DAdaptAdan, |
671 | weight_decay=args.adam_weight_decay, | 678 | weight_decay=args.adam_weight_decay, |
672 | eps=args.adam_epsilon, | 679 | eps=args.adam_epsilon, |
680 | d0=args.dadaptation_d0, | ||
673 | ) | 681 | ) |
674 | else: | 682 | else: |
675 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 683 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
@@ -690,7 +698,6 @@ def main(): | |||
690 | no_val=args.valid_set_size == 0, | 698 | no_val=args.valid_set_size == 0, |
691 | strategy=textual_inversion_strategy, | 699 | strategy=textual_inversion_strategy, |
692 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 700 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
693 | sample_frequency=args.sample_frequency, | ||
694 | checkpoint_frequency=args.checkpoint_frequency, | 701 | checkpoint_frequency=args.checkpoint_frequency, |
695 | milestone_checkpoints=not args.no_milestone_checkpoints, | 702 | milestone_checkpoints=not args.no_milestone_checkpoints, |
696 | global_step_offset=global_step_offset, | 703 | global_step_offset=global_step_offset, |
@@ -759,10 +766,9 @@ def main(): | |||
759 | datamodule.setup() | 766 | datamodule.setup() |
760 | 767 | ||
761 | num_train_epochs = args.num_train_epochs | 768 | num_train_epochs = args.num_train_epochs |
762 | |||
763 | if num_train_epochs is None: | 769 | if num_train_epochs is None: |
764 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size | 770 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
765 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | 771 | sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) |
766 | 772 | ||
767 | optimizer = create_optimizer( | 773 | optimizer = create_optimizer( |
768 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 774 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
@@ -792,6 +798,7 @@ def main(): | |||
792 | optimizer=optimizer, | 798 | optimizer=optimizer, |
793 | lr_scheduler=lr_scheduler, | 799 | lr_scheduler=lr_scheduler, |
794 | num_train_epochs=num_train_epochs, | 800 | num_train_epochs=num_train_epochs, |
801 | sample_frequency=sample_frequency, | ||
795 | # -- | 802 | # -- |
796 | sample_output_dir=sample_output_dir, | 803 | sample_output_dir=sample_output_dir, |
797 | placeholder_tokens=placeholder_tokens, | 804 | placeholder_tokens=placeholder_tokens, |