Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py | 15 +++++++++++----
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f1dca7f..d2e60ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -302,6 +302,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -535,6 +541,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -548,6 +555,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )

         args.learning_rate = 1.0
@@ -596,10 +604,9 @@ def main():
     datamodule.setup()

     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))

     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
@@ -639,7 +646,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
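
For context: a minimal sketch, assuming the surrounding trainer code, of how the new --dadaptation_d0 flag reaches the optimizer. The argparse scaffolding and the Linear stand-in model below are illustrative, but d0 is a real keyword of dadaptation.DAdaptAdam and DAdaptAdan (the initial estimate of the D-Adaptation step-size scale, default 1e-6); lr stays at 1.0 because these optimizers adapt the effective step size themselves, which is why the script also pins args.learning_rate = 1.0.

import argparse

import dadaptation
import torch

parser = argparse.ArgumentParser()
# Mirrors the flag added in the first hunk above.
parser.add_argument(
    "--dadaptation_d0",
    type=float,
    default=1e-6,
    help="The d0 parameter for Dadaptation optimizers."
)
args = parser.parse_args(["--dadaptation_d0", "1e-5"])

model = torch.nn.Linear(4, 4)  # stand-in for the UNet being fine-tuned
optimizer = dadaptation.DAdaptAdam(
    model.parameters(),
    lr=1.0,                   # D-Adaptation expects lr=1.0; d0 seeds the scale
    weight_decay=1e-2,        # illustrative value for args.adam_weight_decay
    eps=1e-8,                 # illustrative value for args.adam_epsilon
    decouple=True,            # decoupled (AdamW-style) weight decay
    d0=args.dadaptation_d0,
)

The fourth hunk's rescaling turns the step-based --sample_frequency into the epoch-based value the trainer now receives: for example, with num_train_steps=2000, 400 training images, and sample_frequency=200 steps, num_train_epochs = ceil(2000 / 400) = 5 and sample_frequency = ceil(5 * 200 / 2000) = 1, i.e. sample once per epoch.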