summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-01 12:34:42 +0100
committerVolpeon <git@volpeon.ink>2023-03-01 12:34:42 +0100
commita1b8327085ddeab589be074d7e9df4291aba1210 (patch)
tree2f2016916d7a2f659268c3e375d55c59583c2b3b /train_dreambooth.py
parentFixed TI normalization order (diff)
downloadtextual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz
textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2
textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 280cf77..6d699f3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,7 +87,7 @@ def parse_args():
87 parser.add_argument( 87 parser.add_argument(
88 "--num_buckets", 88 "--num_buckets",
89 type=int, 89 type=int,
90 default=4, 90 default=0,
91 help="Number of aspect ratio buckets in either direction.", 91 help="Number of aspect ratio buckets in either direction.",
92 ) 92 )
93 parser.add_argument( 93 parser.add_argument(
@@ -305,7 +305,7 @@ def parse_args():
305 parser.add_argument( 305 parser.add_argument(
306 "--adam_weight_decay", 306 "--adam_weight_decay",
307 type=float, 307 type=float,
308 default=0, 308 default=1e-2,
309 help="Weight decay to use." 309 help="Weight decay to use."
310 ) 310 )
311 parser.add_argument( 311 parser.add_argument(
@@ -526,6 +526,7 @@ def main():
526 with_prior_preservation=args.num_class_images != 0, 526 with_prior_preservation=args.num_class_images != 0,
527 prior_loss_weight=args.prior_loss_weight, 527 prior_loss_weight=args.prior_loss_weight,
528 no_val=args.valid_set_size == 0, 528 no_val=args.valid_set_size == 0,
529 # low_freq_noise=0,
529 ) 530 )
530 531
531 checkpoint_output_dir = output_dir / "model" 532 checkpoint_output_dir = output_dir / "model"
@@ -587,7 +588,6 @@ def main():
587 seed=args.seed, 588 seed=args.seed,
588 optimizer=optimizer, 589 optimizer=optimizer,
589 lr_scheduler=lr_scheduler, 590 lr_scheduler=lr_scheduler,
590 prepare_unet=True,
591 num_train_epochs=args.num_train_epochs, 591 num_train_epochs=args.num_train_epochs,
592 gradient_accumulation_steps=args.gradient_accumulation_steps, 592 gradient_accumulation_steps=args.gradient_accumulation_steps,
593 sample_frequency=args.sample_frequency, 593 sample_frequency=args.sample_frequency,