summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-08 20:33:04 +0100
committerVolpeon <git@volpeon.ink>2023-01-08 20:33:04 +0100
commitecb12378da48fc3a17539d5cc33edc561cf8a426 (patch)
tree30517efe41d557a4c1f2661e80e4c0b87e807048 /train_dreambooth.py
parentFixed aspect ratio bucketing (diff)
downloadtextual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.tar.gz
textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.tar.bz2
textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.zip
Improved aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 79eede6..d396249 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -104,6 +104,29 @@ def parse_args():
104 help="Number of epochs the text encoder will be trained." 104 help="Number of epochs the text encoder will be trained."
105 ) 105 )
106 parser.add_argument( 106 parser.add_argument(
107 "--num_buckets",
108 type=int,
109 default=4,
110 help="Number of aspect ratio buckets in either direction.",
111 )
112 parser.add_argument(
113 "--progressive_buckets",
114 action="store_true",
115 help="Include images in smaller buckets as well.",
116 )
117 parser.add_argument(
118 "--bucket_step_size",
119 type=int,
120 default=64,
121 help="Step size between buckets.",
122 )
123 parser.add_argument(
124 "--bucket_max_pixels",
125 type=int,
126 default=None,
127 help="Maximum pixels per bucket.",
128 )
129 parser.add_argument(
107 "--tag_dropout", 130 "--tag_dropout",
108 type=float, 131 type=float,
109 default=0.1, 132 default=0.1,
@@ -734,6 +757,10 @@ def main():
734 class_subdir=args.class_image_dir, 757 class_subdir=args.class_image_dir,
735 num_class_images=args.num_class_images, 758 num_class_images=args.num_class_images,
736 size=args.resolution, 759 size=args.resolution,
760 num_buckets=args.num_buckets,
761 progressive_buckets=args.progressive_buckets,
762 bucket_step_size=args.bucket_step_size,
763 bucket_max_pixels=args.bucket_max_pixels,
737 dropout=args.tag_dropout, 764 dropout=args.tag_dropout,
738 template_key=args.train_data_template, 765 template_key=args.train_data_template,
739 valid_set_size=args.valid_set_size, 766 valid_set_size=args.valid_set_size,