diff options
-rw-r--r-- | data/csv.py | 22 | ||||
-rw-r--r-- | train_dreambooth.py | 27 | ||||
-rw-r--r-- | train_ti.py | 16 |
3 files changed, 61 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py index 7527b7d..55a1988 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -44,18 +44,25 @@ def generate_buckets( | |||
44 | items: list[str], | 44 | items: list[str], |
45 | base_size: int, | 45 | base_size: int, |
46 | step_size: int = 64, | 46 | step_size: int = 64, |
47 | max_pixels: Optional[int] = None, | ||
47 | num_buckets: int = 4, | 48 | num_buckets: int = 4, |
48 | progressive_buckets: bool = False, | 49 | progressive_buckets: bool = False, |
49 | return_tensor: bool = True | 50 | return_tensor: bool = True |
50 | ): | 51 | ): |
52 | if max_pixels is None: | ||
53 | max_pixels = (base_size + step_size) ** 2 | ||
54 | |||
55 | max_pixels = max(max_pixels, base_size * base_size) | ||
56 | |||
51 | bucket_items: list[int] = [] | 57 | bucket_items: list[int] = [] |
52 | bucket_assignments: list[int] = [] | 58 | bucket_assignments: list[int] = [] |
53 | buckets = [1.0] | 59 | buckets = [1.0] |
54 | 60 | ||
55 | for i in range(1, num_buckets + 1): | 61 | for i in range(1, num_buckets + 1): |
56 | s = base_size + i * step_size | 62 | long_side = base_size + i * step_size |
57 | buckets.append(s / base_size) | 63 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) |
58 | buckets.append(base_size / s) | 64 | buckets.append(long_side / short_side) |
65 | buckets.append(short_side / long_side) | ||
59 | 66 | ||
60 | buckets = torch.tensor(buckets) | 67 | buckets = torch.tensor(buckets) |
61 | bucket_indices = torch.arange(len(buckets)) | 68 | bucket_indices = torch.arange(len(buckets)) |
@@ -110,6 +117,8 @@ class VlpnDataModule(): | |||
110 | num_class_images: int = 1, | 117 | num_class_images: int = 1, |
111 | size: int = 768, | 118 | size: int = 768, |
112 | num_buckets: int = 0, | 119 | num_buckets: int = 0, |
120 | bucket_step_size: int = 64, | ||
121 | max_pixels_per_bucket: Optional[int] = None, | ||
113 | progressive_buckets: bool = False, | 122 | progressive_buckets: bool = False, |
114 | dropout: float = 0, | 123 | dropout: float = 0, |
115 | interpolation: str = "bicubic", | 124 | interpolation: str = "bicubic", |
@@ -135,6 +144,8 @@ class VlpnDataModule(): | |||
135 | self.prompt_processor = prompt_processor | 144 | self.prompt_processor = prompt_processor |
136 | self.size = size | 145 | self.size = size |
137 | self.num_buckets = num_buckets | 146 | self.num_buckets = num_buckets |
147 | self.bucket_step_size = bucket_step_size | ||
148 | self.max_pixels_per_bucket = max_pixels_per_bucket | ||
138 | self.progressive_buckets = progressive_buckets | 149 | self.progressive_buckets = progressive_buckets |
139 | self.dropout = dropout | 150 | self.dropout = dropout |
140 | self.template_key = template_key | 151 | self.template_key = template_key |
@@ -223,6 +234,7 @@ class VlpnDataModule(): | |||
223 | train_dataset = VlpnDataset( | 234 | train_dataset = VlpnDataset( |
224 | self.data_train, self.prompt_processor, | 235 | self.data_train, self.prompt_processor, |
225 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 236 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
237 | bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket, | ||
226 | batch_size=self.batch_size, generator=generator, | 238 | batch_size=self.batch_size, generator=generator, |
227 | size=self.size, interpolation=self.interpolation, | 239 | size=self.size, interpolation=self.interpolation, |
228 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 240 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
@@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset): | |||
251 | items: list[VlpnDataItem], | 263 | items: list[VlpnDataItem], |
252 | prompt_processor: PromptProcessor, | 264 | prompt_processor: PromptProcessor, |
253 | num_buckets: int = 1, | 265 | num_buckets: int = 1, |
266 | bucket_step_size: int = 64, | ||
267 | max_pixels_per_bucket: Optional[int] = None, | ||
254 | progressive_buckets: bool = False, | 268 | progressive_buckets: bool = False, |
255 | batch_size: int = 1, | 269 | batch_size: int = 1, |
256 | num_class_images: int = 0, | 270 | num_class_images: int = 0, |
@@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset): | |||
274 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 288 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
275 | [item.instance_image_path for item in items], | 289 | [item.instance_image_path for item in items], |
276 | base_size=size, | 290 | base_size=size, |
291 | step_size=bucket_step_size, | ||
277 | num_buckets=num_buckets, | 292 | num_buckets=num_buckets, |
293 | max_pixels=max_pixels_per_bucket, | ||
278 | progressive_buckets=progressive_buckets, | 294 | progressive_buckets=progressive_buckets, |
279 | ) | 295 | ) |
280 | 296 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 79eede6..d396249 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -104,6 +104,29 @@ def parse_args(): | |||
104 | help="Number of epochs the text encoder will be trained." | 104 | help="Number of epochs the text encoder will be trained." |
105 | ) | 105 | ) |
106 | parser.add_argument( | 106 | parser.add_argument( |
107 | "--num_buckets", | ||
108 | type=int, | ||
109 | default=4, | ||
110 | help="Number of aspect ratio buckets in either direction.", | ||
111 | ) | ||
112 | parser.add_argument( | ||
113 | "--progressive_buckets", | ||
114 | action="store_true", | ||
115 | help="Include images in smaller buckets as well.", | ||
116 | ) | ||
117 | parser.add_argument( | ||
118 | "--bucket_step_size", | ||
119 | type=int, | ||
120 | default=64, | ||
121 | help="Step size between buckets.", | ||
122 | ) | ||
123 | parser.add_argument( | ||
124 | "--bucket_max_pixels", | ||
125 | type=int, | ||
126 | default=None, | ||
127 | help="Maximum pixels per bucket.", | ||
128 | ) | ||
129 | parser.add_argument( | ||
107 | "--tag_dropout", | 130 | "--tag_dropout", |
108 | type=float, | 131 | type=float, |
109 | default=0.1, | 132 | default=0.1, |
@@ -734,6 +757,10 @@ def main(): | |||
734 | class_subdir=args.class_image_dir, | 757 | class_subdir=args.class_image_dir, |
735 | num_class_images=args.num_class_images, | 758 | num_class_images=args.num_class_images, |
736 | size=args.resolution, | 759 | size=args.resolution, |
760 | num_buckets=args.num_buckets, | ||
761 | progressive_buckets=args.progressive_buckets, | ||
762 | bucket_step_size=args.bucket_step_size, | ||
763 | bucket_max_pixels=args.bucket_max_pixels, | ||
737 | dropout=args.tag_dropout, | 764 | dropout=args.tag_dropout, |
738 | template_key=args.train_data_template, | 765 | template_key=args.train_data_template, |
739 | valid_set_size=args.valid_set_size, | 766 | valid_set_size=args.valid_set_size, |
diff --git a/train_ti.py b/train_ti.py index 323ef10..eb0b8b6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -143,7 +143,7 @@ def parse_args(): | |||
143 | "--num_buckets", | 143 | "--num_buckets", |
144 | type=int, | 144 | type=int, |
145 | default=4, | 145 | default=4, |
146 | help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", | 146 | help="Number of aspect ratio buckets in either direction.", |
147 | ) | 147 | ) |
148 | parser.add_argument( | 148 | parser.add_argument( |
149 | "--progressive_buckets", | 149 | "--progressive_buckets", |
@@ -151,6 +151,18 @@ def parse_args(): | |||
151 | help="Include images in smaller buckets as well.", | 151 | help="Include images in smaller buckets as well.", |
152 | ) | 152 | ) |
153 | parser.add_argument( | 153 | parser.add_argument( |
154 | "--bucket_step_size", | ||
155 | type=int, | ||
156 | default=64, | ||
157 | help="Step size between buckets.", | ||
158 | ) | ||
159 | parser.add_argument( | ||
160 | "--bucket_max_pixels", | ||
161 | type=int, | ||
162 | default=None, | ||
163 | help="Maximum pixels per bucket.", | ||
164 | ) | ||
165 | parser.add_argument( | ||
154 | "--tag_dropout", | 166 | "--tag_dropout", |
155 | type=float, | 167 | type=float, |
156 | default=0.1, | 168 | default=0.1, |
@@ -718,6 +730,8 @@ def main(): | |||
718 | size=args.resolution, | 730 | size=args.resolution, |
719 | num_buckets=args.num_buckets, | 731 | num_buckets=args.num_buckets, |
720 | progressive_buckets=args.progressive_buckets, | 732 | progressive_buckets=args.progressive_buckets, |
733 | bucket_step_size=args.bucket_step_size, | ||
734 | bucket_max_pixels=args.bucket_max_pixels, | ||
721 | dropout=args.tag_dropout, | 735 | dropout=args.tag_dropout, |
722 | template_key=args.train_data_template, | 736 | template_key=args.train_data_template, |
723 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, |