diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-08 20:33:04 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-08 20:33:04 +0100 |
| commit | ecb12378da48fc3a17539d5cc33edc561cf8a426 (patch) | |
| tree | 30517efe41d557a4c1f2661e80e4c0b87e807048 /train_dreambooth.py | |
| parent | Fixed aspect ratio bucketing (diff) | |
| download | textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.tar.gz textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.tar.bz2 textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.zip | |
Improved aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 79eede6..d396249 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -104,6 +104,29 @@ def parse_args(): | |||
| 104 | help="Number of epochs the text encoder will be trained." | 104 | help="Number of epochs the text encoder will be trained." |
| 105 | ) | 105 | ) |
| 106 | parser.add_argument( | 106 | parser.add_argument( |
| 107 | "--num_buckets", | ||
| 108 | type=int, | ||
| 109 | default=4, | ||
| 110 | help="Number of aspect ratio buckets in either direction.", | ||
| 111 | ) | ||
| 112 | parser.add_argument( | ||
| 113 | "--progressive_buckets", | ||
| 114 | action="store_true", | ||
| 115 | help="Include images in smaller buckets as well.", | ||
| 116 | ) | ||
| 117 | parser.add_argument( | ||
| 118 | "--bucket_step_size", | ||
| 119 | type=int, | ||
| 120 | default=64, | ||
| 121 | help="Step size between buckets.", | ||
| 122 | ) | ||
| 123 | parser.add_argument( | ||
| 124 | "--bucket_max_pixels", | ||
| 125 | type=int, | ||
| 126 | default=None, | ||
| 127 | help="Maximum pixels per bucket.", | ||
| 128 | ) | ||
| 129 | parser.add_argument( | ||
| 107 | "--tag_dropout", | 130 | "--tag_dropout", |
| 108 | type=float, | 131 | type=float, |
| 109 | default=0.1, | 132 | default=0.1, |
| @@ -734,6 +757,10 @@ def main(): | |||
| 734 | class_subdir=args.class_image_dir, | 757 | class_subdir=args.class_image_dir, |
| 735 | num_class_images=args.num_class_images, | 758 | num_class_images=args.num_class_images, |
| 736 | size=args.resolution, | 759 | size=args.resolution, |
| 760 | num_buckets=args.num_buckets, | ||
| 761 | progressive_buckets=args.progressive_buckets, | ||
| 762 | bucket_step_size=args.bucket_step_size, | ||
| 763 | bucket_max_pixels=args.bucket_max_pixels, | ||
| 737 | dropout=args.tag_dropout, | 764 | dropout=args.tag_dropout, |
| 738 | template_key=args.train_data_template, | 765 | template_key=args.train_data_template, |
| 739 | valid_set_size=args.valid_set_size, | 766 | valid_set_size=args.valid_set_size, |
