summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 744d1bc..88cd0da 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -638,6 +638,7 @@ def main():
638 identifier=args.identifier, 638 identifier=args.identifier,
639 repeats=args.repeats, 639 repeats=args.repeats,
640 center_crop=args.center_crop, 640 center_crop=args.center_crop,
641 valid_set_size=args.sample_batch_size*args.stable_sample_batches,
641 collate_fn=collate_fn) 642 collate_fn=collate_fn)
642 643
643 datamodule.prepare_data() 644 datamodule.prepare_data()
@@ -658,7 +659,7 @@ def main():
658 sample_batch_size=args.sample_batch_size, 659 sample_batch_size=args.sample_batch_size,
659 random_sample_batches=args.random_sample_batches, 660 random_sample_batches=args.random_sample_batches,
660 stable_sample_batches=args.stable_sample_batches, 661 stable_sample_batches=args.stable_sample_batches,
661 seed=args.seed 662 seed=args.seed or torch.random.seed()
662 ) 663 )
663 664
664 # Scheduler and math around the number of training steps. 665 # Scheduler and math around the number of training steps.