diff options
author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /train_dreambooth.py | |
parent | Fixed TI normalization order (diff) | |
download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 280cf77..6d699f3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -87,7 +87,7 @@ def parse_args(): | |||
87 | parser.add_argument( | 87 | parser.add_argument( |
88 | "--num_buckets", | 88 | "--num_buckets", |
89 | type=int, | 89 | type=int, |
90 | default=4, | 90 | default=0, |
91 | help="Number of aspect ratio buckets in either direction.", | 91 | help="Number of aspect ratio buckets in either direction.", |
92 | ) | 92 | ) |
93 | parser.add_argument( | 93 | parser.add_argument( |
@@ -305,7 +305,7 @@ def parse_args(): | |||
305 | parser.add_argument( | 305 | parser.add_argument( |
306 | "--adam_weight_decay", | 306 | "--adam_weight_decay", |
307 | type=float, | 307 | type=float, |
308 | default=0, | 308 | default=1e-2, |
309 | help="Weight decay to use." | 309 | help="Weight decay to use." |
310 | ) | 310 | ) |
311 | parser.add_argument( | 311 | parser.add_argument( |
@@ -526,6 +526,7 @@ def main(): | |||
526 | with_prior_preservation=args.num_class_images != 0, | 526 | with_prior_preservation=args.num_class_images != 0, |
527 | prior_loss_weight=args.prior_loss_weight, | 527 | prior_loss_weight=args.prior_loss_weight, |
528 | no_val=args.valid_set_size == 0, | 528 | no_val=args.valid_set_size == 0, |
529 | # low_freq_noise=0, | ||
529 | ) | 530 | ) |
530 | 531 | ||
531 | checkpoint_output_dir = output_dir / "model" | 532 | checkpoint_output_dir = output_dir / "model" |
@@ -587,7 +588,6 @@ def main(): | |||
587 | seed=args.seed, | 588 | seed=args.seed, |
588 | optimizer=optimizer, | 589 | optimizer=optimizer, |
589 | lr_scheduler=lr_scheduler, | 590 | lr_scheduler=lr_scheduler, |
590 | prepare_unet=True, | ||
591 | num_train_epochs=args.num_train_epochs, | 591 | num_train_epochs=args.num_train_epochs, |
592 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 592 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
593 | sample_frequency=args.sample_frequency, | 593 | sample_frequency=args.sample_frequency, |