diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
| commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
| tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /train_dreambooth.py | |
| parent | Fixed TI normalization order (diff) | |
| download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 280cf77..6d699f3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -87,7 +87,7 @@ def parse_args(): | |||
| 87 | parser.add_argument( | 87 | parser.add_argument( |
| 88 | "--num_buckets", | 88 | "--num_buckets", |
| 89 | type=int, | 89 | type=int, |
| 90 | default=4, | 90 | default=0, |
| 91 | help="Number of aspect ratio buckets in either direction.", | 91 | help="Number of aspect ratio buckets in either direction.", |
| 92 | ) | 92 | ) |
| 93 | parser.add_argument( | 93 | parser.add_argument( |
| @@ -305,7 +305,7 @@ def parse_args(): | |||
| 305 | parser.add_argument( | 305 | parser.add_argument( |
| 306 | "--adam_weight_decay", | 306 | "--adam_weight_decay", |
| 307 | type=float, | 307 | type=float, |
| 308 | default=0, | 308 | default=1e-2, |
| 309 | help="Weight decay to use." | 309 | help="Weight decay to use." |
| 310 | ) | 310 | ) |
| 311 | parser.add_argument( | 311 | parser.add_argument( |
| @@ -526,6 +526,7 @@ def main(): | |||
| 526 | with_prior_preservation=args.num_class_images != 0, | 526 | with_prior_preservation=args.num_class_images != 0, |
| 527 | prior_loss_weight=args.prior_loss_weight, | 527 | prior_loss_weight=args.prior_loss_weight, |
| 528 | no_val=args.valid_set_size == 0, | 528 | no_val=args.valid_set_size == 0, |
| 529 | # low_freq_noise=0, | ||
| 529 | ) | 530 | ) |
| 530 | 531 | ||
| 531 | checkpoint_output_dir = output_dir / "model" | 532 | checkpoint_output_dir = output_dir / "model" |
| @@ -587,7 +588,6 @@ def main(): | |||
| 587 | seed=args.seed, | 588 | seed=args.seed, |
| 588 | optimizer=optimizer, | 589 | optimizer=optimizer, |
| 589 | lr_scheduler=lr_scheduler, | 590 | lr_scheduler=lr_scheduler, |
| 590 | prepare_unet=True, | ||
| 591 | num_train_epochs=args.num_train_epochs, | 591 | num_train_epochs=args.num_train_epochs, |
| 592 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 592 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 593 | sample_frequency=args.sample_frequency, | 593 | sample_frequency=args.sample_frequency, |
