diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 3 | 
1 files changed, 2 insertions, 1 deletions
| diff --git a/dreambooth.py b/dreambooth.py index 744d1bc..88cd0da 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -638,6 +638,7 @@ def main(): | |||
| 638 | identifier=args.identifier, | 638 | identifier=args.identifier, | 
| 639 | repeats=args.repeats, | 639 | repeats=args.repeats, | 
| 640 | center_crop=args.center_crop, | 640 | center_crop=args.center_crop, | 
| 641 | valid_set_size=args.sample_batch_size*args.stable_sample_batches, | ||
| 641 | collate_fn=collate_fn) | 642 | collate_fn=collate_fn) | 
| 642 | 643 | ||
| 643 | datamodule.prepare_data() | 644 | datamodule.prepare_data() | 
| @@ -658,7 +659,7 @@ def main(): | |||
| 658 | sample_batch_size=args.sample_batch_size, | 659 | sample_batch_size=args.sample_batch_size, | 
| 659 | random_sample_batches=args.random_sample_batches, | 660 | random_sample_batches=args.random_sample_batches, | 
| 660 | stable_sample_batches=args.stable_sample_batches, | 661 | stable_sample_batches=args.stable_sample_batches, | 
| 661 | seed=args.seed | 662 | seed=args.seed or torch.random.seed() | 
| 662 | ) | 663 | ) | 
| 663 | 664 | ||
| 664 | # Scheduler and math around the number of training steps. | 665 | # Scheduler and math around the number of training steps. | 
