summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 11:44:42 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 11:44:42 +0200
commitd0ce16b542deac464e097c38adc5095802bd6763 (patch)
tree8dc3295ac2a03c46eb1add566691134335ee6657 /dreambooth.py
parentAdded script to convert Differs -> SD (diff)
downloadtextual-inversion-diff-d0ce16b542deac464e097c38adc5095802bd6763.tar.gz
textual-inversion-diff-d0ce16b542deac464e097c38adc5095802bd6763.tar.bz2
textual-inversion-diff-d0ce16b542deac464e097c38adc5095802bd6763.zip
Assign unused images in validation dataset to train dataset
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 744d1bc..88cd0da 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -638,6 +638,7 @@ def main():
638 identifier=args.identifier, 638 identifier=args.identifier,
639 repeats=args.repeats, 639 repeats=args.repeats,
640 center_crop=args.center_crop, 640 center_crop=args.center_crop,
641 valid_set_size=args.sample_batch_size*args.stable_sample_batches,
641 collate_fn=collate_fn) 642 collate_fn=collate_fn)
642 643
643 datamodule.prepare_data() 644 datamodule.prepare_data()
@@ -658,7 +659,7 @@ def main():
658 sample_batch_size=args.sample_batch_size, 659 sample_batch_size=args.sample_batch_size,
659 random_sample_batches=args.random_sample_batches, 660 random_sample_batches=args.random_sample_batches,
660 stable_sample_batches=args.stable_sample_batches, 661 stable_sample_batches=args.stable_sample_batches,
661 seed=args.seed 662 seed=args.seed or torch.random.seed()
662 ) 663 )
663 664
664 # Scheduler and math around the number of training steps. 665 # Scheduler and math around the number of training steps.