From e8d17379efa491a019c3a5fc2633e55ab87e3432 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 31 Mar 2023 12:48:18 +0200
Subject: Support Dadaptation d0, adjust sample freq when steps instead of
 epochs are used

---
 train_dreambooth.py | 15 +++++++++++----
 train_lora.py       | 15 +++++++++++----
 train_ti.py         | 15 +++++++++++----
 3 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index f1dca7f..d2e60ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -301,6 +301,12 @@ def parse_args():
         default="dadan",
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
+    parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
     parser.add_argument(
         "--adam_beta1",
         type=float,
@@ -535,6 +541,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -548,6 +555,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -596,10 +604,9 @@ def main():
     datamodule.setup()

     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))

     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
@@ -639,7 +646,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
diff --git a/train_lora.py b/train_lora.py
index 9975462..7b54ef8 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -313,6 +313,12 @@ def parse_args():
         default="dadan",
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
+    parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
     parser.add_argument(
         "--adam_beta1",
         type=float,
@@ -567,6 +573,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -580,6 +587,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -628,10 +636,9 @@ def main():
     datamodule.setup()

     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
         itertools.chain(
@@ -667,7 +674,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
diff --git a/train_ti.py b/train_ti.py
index b7ea5f3..902f508 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -319,6 +319,12 @@ def parse_args():
         default="dadan",
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
+    parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
     parser.add_argument(
         "--adam_beta1",
         type=float,
@@ -659,6 +665,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )
     elif args.optimizer == 'dadan':
         try:
@@ -670,6 +677,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -690,7 +698,6 @@ def main():
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
@@ -759,10 +766,9 @@ def main():
     datamodule.setup()

     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
@@ -792,6 +798,7 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
+        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
--
cgit v1.2.3-70-g09d2
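
Note on the schedule conversion in the hunks above: when --num_train_steps drives training, the patch derives the epoch count from the dataset size and rescales --sample_frequency (given in steps) into epochs, so samples are still generated at roughly the same points in training. The snippet below is a minimal standalone sketch of that arithmetic; the function and argument names are illustrative and not part of the patch, and it assumes, as the patch does, that one epoch corresponds to one pass over the training dataset.

    import math

    def derive_epoch_schedule(num_train_steps, dataset_len, sample_frequency_steps):
        # Hypothetical helper mirroring the patch: one epoch is one pass over
        # the dataset, so the epoch count is steps divided by dataset size.
        num_train_epochs = math.ceil(num_train_steps / dataset_len)
        # Rescale the sampling interval from steps to epochs so samples still
        # appear about every `sample_frequency_steps` optimizer steps.
        sample_frequency = math.ceil(num_train_epochs * (sample_frequency_steps / num_train_steps))
        return num_train_epochs, sample_frequency

    # Example: 10,000 steps over a 400-image dataset, sampling every 1,000 steps
    # -> 25 epochs, with samples every 3 epochs (2.5 rounded up).
    print(derive_epoch_schedule(10_000, 400, 1_000))

As for the other half of the patch, d0 is D-Adaptation's initial value for its D estimate (the quantity that scales the adaptive step size); both DAdaptAdam and DAdaptAdan accept it, which is why --dadaptation_d0 is threaded through both optimizer branches while the learning rate stays pinned at 1.0.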