Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 15 +++++++++++----
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index b7ea5f3..902f508 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -320,6 +320,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -659,6 +665,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )
     elif args.optimizer == 'dadan':
         try:
@@ -670,6 +677,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -690,7 +698,6 @@ def main():
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
@@ -759,10 +766,9 @@ def main():
     datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
@@ -792,6 +798,7 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
+        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
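
A minimal sketch of what this change wires up, for context only. The concrete numbers below (dataset size, step counts, weight decay, epsilon) are assumptions rather than values from the repository; only the keyword arguments and the rescaling arithmetic mirror the diff above.

import math
import torch
import dadaptation  # assumes the dadaptation package is installed

# Assumed example values: 100 training images, --num_train_steps 3000, --sample_frequency 300.
dataset_len, num_train_steps, sample_frequency_steps = 100, 3000, 300

# The epoch count is now derived from the dataset length alone ...
num_train_epochs = math.ceil(num_train_steps / dataset_len)  # -> 30
# ... and the step-based --sample_frequency is rescaled into an epoch interval.
sample_frequency = math.ceil(num_train_epochs * (sample_frequency_steps / num_train_steps))  # -> 3

# --dadaptation_d0 (default 1e-6) is forwarded as the optimizer's d0 argument,
# D-Adaptation's initial estimate of the step-size scale.
params = [torch.nn.Parameter(torch.zeros(8))]  # stand-in for the embedding parameters
optimizer = dadaptation.DAdaptAdam(
    params,
    weight_decay=1e-2,  # the script passes args.adam_weight_decay
    eps=1e-8,           # the script passes args.adam_epsilon
    decouple=True,
    d0=1e-6,            # the script passes args.dadaptation_d0
)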