diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 280cf77..6d699f3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -87,7 +87,7 @@ def parse_args(): | |||
87 | parser.add_argument( | 87 | parser.add_argument( |
88 | "--num_buckets", | 88 | "--num_buckets", |
89 | type=int, | 89 | type=int, |
90 | default=4, | 90 | default=0, |
91 | help="Number of aspect ratio buckets in either direction.", | 91 | help="Number of aspect ratio buckets in either direction.", |
92 | ) | 92 | ) |
93 | parser.add_argument( | 93 | parser.add_argument( |
@@ -305,7 +305,7 @@ def parse_args(): | |||
305 | parser.add_argument( | 305 | parser.add_argument( |
306 | "--adam_weight_decay", | 306 | "--adam_weight_decay", |
307 | type=float, | 307 | type=float, |
308 | default=0, | 308 | default=1e-2, |
309 | help="Weight decay to use." | 309 | help="Weight decay to use." |
310 | ) | 310 | ) |
311 | parser.add_argument( | 311 | parser.add_argument( |
@@ -526,6 +526,7 @@ def main(): | |||
526 | with_prior_preservation=args.num_class_images != 0, | 526 | with_prior_preservation=args.num_class_images != 0, |
527 | prior_loss_weight=args.prior_loss_weight, | 527 | prior_loss_weight=args.prior_loss_weight, |
528 | no_val=args.valid_set_size == 0, | 528 | no_val=args.valid_set_size == 0, |
529 | # low_freq_noise=0, | ||
529 | ) | 530 | ) |
530 | 531 | ||
531 | checkpoint_output_dir = output_dir / "model" | 532 | checkpoint_output_dir = output_dir / "model" |
@@ -587,7 +588,6 @@ def main(): | |||
587 | seed=args.seed, | 588 | seed=args.seed, |
588 | optimizer=optimizer, | 589 | optimizer=optimizer, |
589 | lr_scheduler=lr_scheduler, | 590 | lr_scheduler=lr_scheduler, |
590 | prepare_unet=True, | ||
591 | num_train_epochs=args.num_train_epochs, | 591 | num_train_epochs=args.num_train_epochs, |
592 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 592 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
593 | sample_frequency=args.sample_frequency, | 593 | sample_frequency=args.sample_frequency, |