Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  15
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f1dca7f..d2e60ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -302,6 +302,12 @@ def parse_args():
         help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
+        "--dadaptation_d0",
+        type=float,
+        default=1e-6,
+        help="The d0 parameter for Dadaptation optimizers."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -535,6 +541,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=True,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate = 1.0
@@ -548,6 +555,7 @@ def main():
             dadaptation.DAdaptAdan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate = 1.0
@@ -596,10 +604,9 @@ def main():
     datamodule.setup()
 
     num_train_epochs = args.num_train_epochs
-
     if num_train_epochs is None:
-        num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
-        num_train_epochs = math.ceil(args.num_train_steps / num_images)
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+        sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps))
 
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
@@ -639,7 +646,7 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=args.sample_frequency,
+        sample_frequency=sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
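
For context: the patch adds a --dadaptation_d0 flag and threads it through to both D-Adaptation optimizers as their d0 keyword, the initial estimate of the distance to the solution that the method adapts from. A minimal sketch of the resulting construction, assuming the dadaptation package is installed and using a placeholder module in place of the actual UNet; the hyperparameter values are illustrative, not the script's defaults:

import dadaptation
import torch

model = torch.nn.Linear(8, 8)  # hypothetical stand-in for the fine-tuned UNet

optimizer = dadaptation.DAdaptAdam(
    model.parameters(),
    lr=1.0,            # D-Adaptation expects lr=1.0; the script sets args.learning_rate = 1.0 accordingly
    weight_decay=1e-2,
    eps=1e-8,
    decouple=True,     # AdamW-style decoupled weight decay, as in the patch
    d0=1e-6,           # initial D estimate, the new flag's default
)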
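The reworked epoch derivation no longer rounds the dataset size up to a full batch, and it rescales --sample_frequency from the step axis onto the new epoch axis. A worked example with hypothetical numbers (not taken from the patch) tracing the two new expressions:

import math

num_train_steps = 10_000        # --num_train_steps
dataset_len = 400               # len(datamodule.train_dataset), hypothetical
sample_frequency_steps = 1_000  # --sample_frequency, expressed in steps

# One epoch covers the whole dataset, so the step budget converts directly:
num_train_epochs = math.ceil(num_train_steps / dataset_len)  # ceil(25.0) = 25

# The sampling interval is mapped onto the same epoch scale:
sample_frequency = math.ceil(
    num_train_epochs * (sample_frequency_steps / num_train_steps)
)  # ceil(25 * 0.1) = 3, i.e. sample roughly every 3 epochs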