author     Volpeon <git@volpeon.ink>  2023-03-31 12:48:18 +0200
committer  Volpeon <git@volpeon.ink>  2023-03-31 12:48:18 +0200
commit     e8d17379efa491a019c3a5fc2633e55ab87e3432 (patch)
tree       0c65fefd5a8fe91d516fda67655b16eae2fa6d91
parent     Fix (diff)
Support Dadaptation d0, adjust sample freq when steps instead of epochs are used
-rw-r--r--  train_dreambooth.py  15
-rw-r--r--  train_lora.py        15
-rw-r--r--  train_ti.py          15
3 files changed, 33 insertions, 12 deletions
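
The sampling change matters when the run length is given as --num_train_steps rather than --num_train_epochs: --sample_frequency (specified in steps) is now rescaled to an epoch count, presumably because the trainers interpret sample_frequency per epoch. A minimal sketch of that arithmetic, reusing the expressions added in the diffs below (the concrete numbers are only illustrative):

import math

num_train_steps = 3000   # --num_train_steps (illustrative)
sample_frequency = 100   # --sample_frequency, given in steps (illustrative)
dataset_len = 150        # stands in for len(datamodule.train_dataset)

# Derive the epoch count from the requested number of steps, then express
# the sampling interval in epochs instead of steps.
num_train_epochs = math.ceil(num_train_steps / dataset_len)                            # 20
sample_frequency = math.ceil(num_train_epochs * (sample_frequency / num_train_steps))  # 1
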
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f1dca7f..d2e60ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -302,6 +302,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -535,6 +541,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate = 1.0
@@ -548,6 +555,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate = 1.0
@@ -596,10 +604,9 @@ def main():
     datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))
 
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
@@ -639,7 +646,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
diff --git a/train_lora.py b/train_lora.py
index 9975462..7b54ef8 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -314,6 +314,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -567,6 +573,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate = 1.0
@@ -580,6 +587,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate = 1.0
@@ -628,10 +636,9 @@ def main():
     datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
         itertools.chain(
@@ -667,7 +674,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
diff --git a/train_ti.py b/train_ti.py
index b7ea5f3..902f508 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -320,6 +320,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -659,6 +665,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )
     elif args.optimizer == 'dadan':
         try:
@@ -670,6 +677,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -690,7 +698,6 @@ def main():
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
@@ -759,10 +766,9 @@ def main():
     datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
@@ -792,6 +798,7 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
+        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
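
For context on the optimizer half of the change, a minimal standalone sketch of what --dadaptation_d0 feeds into (assuming the dadaptation package is installed; the tiny model and the weight_decay/eps values are purely illustrative). In D-Adaptation, d0 is the optimizer's initial estimate of the distance to the solution; 1e-6 matches both the library default and the new flag's default:

import torch
import dadaptation

model = torch.nn.Linear(4, 4)
optimizer = dadaptation.DAdaptAdam(
    model.parameters(),
    lr=1.0,             # the training scripts force args.learning_rate = 1.0 for D-Adaptation
    weight_decay=1e-2,  # illustrative
    eps=1e-8,           # illustrative
    decouple=True,      # decoupled weight decay, as in the scripts
    d0=1e-6,            # exposed above as --dadaptation_d0
)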