diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-08 20:33:04 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-08 20:33:04 +0100 |
| commit | ecb12378da48fc3a17539d5cc33edc561cf8a426 (patch) | |
| tree | 30517efe41d557a4c1f2661e80e4c0b87e807048 /train_ti.py | |
| parent | Fixed aspect ratio bucketing (diff) | |
| download | textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.tar.gz textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.tar.bz2 textual-inversion-diff-ecb12378da48fc3a17539d5cc33edc561cf8a426.zip | |
Improved aspect ratio bucketing
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 323ef10..eb0b8b6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -143,7 +143,7 @@ def parse_args(): | |||
| 143 | "--num_buckets", | 143 | "--num_buckets", |
| 144 | type=int, | 144 | type=int, |
| 145 | default=4, | 145 | default=4, |
| 146 | help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", | 146 | help="Number of aspect ratio buckets in either direction.", |
| 147 | ) | 147 | ) |
| 148 | parser.add_argument( | 148 | parser.add_argument( |
| 149 | "--progressive_buckets", | 149 | "--progressive_buckets", |
| @@ -151,6 +151,18 @@ def parse_args(): | |||
| 151 | help="Include images in smaller buckets as well.", | 151 | help="Include images in smaller buckets as well.", |
| 152 | ) | 152 | ) |
| 153 | parser.add_argument( | 153 | parser.add_argument( |
| 154 | "--bucket_step_size", | ||
| 155 | type=int, | ||
| 156 | default=64, | ||
| 157 | help="Step size between buckets.", | ||
| 158 | ) | ||
| 159 | parser.add_argument( | ||
| 160 | "--bucket_max_pixels", | ||
| 161 | type=int, | ||
| 162 | default=None, | ||
| 163 | help="Maximum pixels per bucket.", | ||
| 164 | ) | ||
| 165 | parser.add_argument( | ||
| 154 | "--tag_dropout", | 166 | "--tag_dropout", |
| 155 | type=float, | 167 | type=float, |
| 156 | default=0.1, | 168 | default=0.1, |
| @@ -718,6 +730,8 @@ def main(): | |||
| 718 | size=args.resolution, | 730 | size=args.resolution, |
| 719 | num_buckets=args.num_buckets, | 731 | num_buckets=args.num_buckets, |
| 720 | progressive_buckets=args.progressive_buckets, | 732 | progressive_buckets=args.progressive_buckets, |
| 733 | bucket_step_size=args.bucket_step_size, | ||
| 734 | bucket_max_pixels=args.bucket_max_pixels, | ||
| 721 | dropout=args.tag_dropout, | 735 | dropout=args.tag_dropout, |
| 722 | template_key=args.train_data_template, | 736 | template_key=args.train_data_template, |
| 723 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, |
