diff options
-rw-r--r-- | data/csv.py | 27 | ||||
-rw-r--r-- | train_dreambooth.py | 15 | ||||
-rw-r--r-- | train_lora.py | 2 | ||||
-rw-r--r-- | train_ti.py | 2 |
4 files changed, 35 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py index 43bf14c..c38db6d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple): | |||
156 | 156 | ||
157 | def full_prompt( | 157 | def full_prompt( |
158 | self, | 158 | self, |
159 | dropout: float = 0, | 159 | prompt_dropout: float = 0, |
160 | tag_dropout: float = 0, | ||
160 | shuffle: bool = False, | 161 | shuffle: bool = False, |
161 | npgenerator: Optional[np.random.Generator] = None, | 162 | npgenerator: Optional[np.random.Generator] = None, |
162 | ): | 163 | ): |
164 | if prompt_dropout != 0 and np.random.random() <= prompt_dropout: | ||
165 | return "" | ||
166 | |||
163 | return keywords_to_str( | 167 | return keywords_to_str( |
164 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | 168 | self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator |
165 | ) | 169 | ) |
166 | 170 | ||
167 | 171 | ||
@@ -200,7 +204,8 @@ class VlpnDataModule: | |||
200 | bucket_step_size: int = 64, | 204 | bucket_step_size: int = 64, |
201 | bucket_max_pixels: Optional[int] = None, | 205 | bucket_max_pixels: Optional[int] = None, |
202 | progressive_buckets: bool = False, | 206 | progressive_buckets: bool = False, |
203 | dropout: float = 0, | 207 | prompt_dropout: float = 0, |
208 | tag_dropout: float = 0, | ||
204 | shuffle: bool = False, | 209 | shuffle: bool = False, |
205 | interpolation: str = "bicubic", | 210 | interpolation: str = "bicubic", |
206 | color_jitter: bool = False, | 211 | color_jitter: bool = False, |
@@ -236,7 +241,8 @@ class VlpnDataModule: | |||
236 | self.bucket_step_size = bucket_step_size | 241 | self.bucket_step_size = bucket_step_size |
237 | self.bucket_max_pixels = bucket_max_pixels | 242 | self.bucket_max_pixels = bucket_max_pixels |
238 | self.progressive_buckets = progressive_buckets | 243 | self.progressive_buckets = progressive_buckets |
239 | self.dropout = dropout | 244 | self.prompt_dropout = prompt_dropout |
245 | self.tag_dropout = tag_dropout | ||
240 | self.shuffle = shuffle | 246 | self.shuffle = shuffle |
241 | self.template_key = template_key | 247 | self.template_key = template_key |
242 | self.interpolation = interpolation | 248 | self.interpolation = interpolation |
@@ -382,7 +388,8 @@ class VlpnDataModule: | |||
382 | interpolation=self.interpolation, | 388 | interpolation=self.interpolation, |
383 | color_jitter=self.color_jitter, | 389 | color_jitter=self.color_jitter, |
384 | num_class_images=self.num_class_images, | 390 | num_class_images=self.num_class_images, |
385 | dropout=self.dropout, | 391 | tag_dropout=self.tag_dropout, |
392 | prompt_dropout=self.prompt_dropout, | ||
386 | shuffle=self.shuffle, | 393 | shuffle=self.shuffle, |
387 | ) | 394 | ) |
388 | 395 | ||
@@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset): | |||
433 | fill_batch: bool = False, | 440 | fill_batch: bool = False, |
434 | num_class_images: int = 0, | 441 | num_class_images: int = 0, |
435 | size: int = 768, | 442 | size: int = 768, |
436 | dropout: float = 0, | 443 | tag_dropout: float = 0, |
444 | prompt_dropout: float = 0, | ||
437 | shuffle: bool = False, | 445 | shuffle: bool = False, |
438 | interpolation: str = "bicubic", | 446 | interpolation: str = "bicubic", |
439 | color_jitter: bool = False, | 447 | color_jitter: bool = False, |
@@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset): | |||
447 | self.tokenizer = tokenizer | 455 | self.tokenizer = tokenizer |
448 | self.num_class_images = num_class_images | 456 | self.num_class_images = num_class_images |
449 | self.size = size | 457 | self.size = size |
450 | self.dropout = dropout | 458 | self.tag_dropout = tag_dropout |
459 | self.prompt_dropout = prompt_dropout | ||
451 | self.shuffle = shuffle | 460 | self.shuffle = shuffle |
452 | self.interpolation = interpolations[interpolation] | 461 | self.interpolation = interpolations[interpolation] |
453 | self.color_jitter = color_jitter | 462 | self.color_jitter = color_jitter |
@@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset): | |||
558 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
559 | 568 | ||
560 | example["instance_prompt_ids"] = self.get_input_ids( | 569 | example["instance_prompt_ids"] = self.get_input_ids( |
561 | item.full_prompt(self.dropout, True, self.npgenerator) | 570 | item.full_prompt( |
571 | self.prompt_dropout, self.tag_dropout, True, self.npgenerator | ||
572 | ) | ||
562 | ) | 573 | ) |
563 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 574 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
564 | example["instance_images"] = image_transforms( | 575 | example["instance_images"] = image_transforms( |
diff --git a/train_dreambooth.py b/train_dreambooth.py index ab3ed16..7745d27 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -189,6 +189,12 @@ def parse_args(): | |||
189 | help="Tag dropout probability.", | 189 | help="Tag dropout probability.", |
190 | ) | 190 | ) |
191 | parser.add_argument( | 191 | parser.add_argument( |
192 | "--prompt_dropout", | ||
193 | type=float, | ||
194 | default=0, | ||
195 | help="Prompt dropout probability.", | ||
196 | ) | ||
197 | parser.add_argument( | ||
192 | "--no_tag_shuffle", | 198 | "--no_tag_shuffle", |
193 | action="store_true", | 199 | action="store_true", |
194 | help="Shuffle tags.", | 200 | help="Shuffle tags.", |
@@ -255,6 +261,11 @@ def parse_args(): | |||
255 | help="Number of epochs the text encoder will be trained.", | 261 | help="Number of epochs the text encoder will be trained.", |
256 | ) | 262 | ) |
257 | parser.add_argument( | 263 | parser.add_argument( |
264 | "--text_encoder_unfreeze_last_n_layers", | ||
265 | default=2, | ||
266 | help="Number of text encoder layers to train.", | ||
267 | ) | ||
268 | parser.add_argument( | ||
258 | "--find_lr", | 269 | "--find_lr", |
259 | action="store_true", | 270 | action="store_true", |
260 | help="Automatically find a learning rate (no training).", | 271 | help="Automatically find a learning rate (no training).", |
@@ -908,7 +919,8 @@ def main(): | |||
908 | dreambooth_datamodule = create_datamodule( | 919 | dreambooth_datamodule = create_datamodule( |
909 | valid_set_size=args.valid_set_size, | 920 | valid_set_size=args.valid_set_size, |
910 | batch_size=args.train_batch_size, | 921 | batch_size=args.train_batch_size, |
911 | dropout=args.tag_dropout, | 922 | tag_dropout=args.tag_dropout, |
923 | prompt_dropout=args.prompt_dropout, | ||
912 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 924 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
913 | ) | 925 | ) |
914 | dreambooth_datamodule.setup() | 926 | dreambooth_datamodule.setup() |
@@ -1051,6 +1063,7 @@ def main(): | |||
1051 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1063 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, |
1052 | sample_frequency=dreambooth_sample_frequency, | 1064 | sample_frequency=dreambooth_sample_frequency, |
1053 | input_pertubation=args.input_pertubation, | 1065 | input_pertubation=args.input_pertubation, |
1066 | text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers, | ||
1054 | no_val=args.valid_set_size == 0, | 1067 | no_val=args.valid_set_size == 0, |
1055 | avg_loss=avg_loss, | 1068 | avg_loss=avg_loss, |
1056 | avg_acc=avg_acc, | 1069 | avg_acc=avg_acc, |
diff --git a/train_lora.py b/train_lora.py index 51dc827..1ff25ff 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -1137,7 +1137,7 @@ def main(): | |||
1137 | lora_datamodule = create_datamodule( | 1137 | lora_datamodule = create_datamodule( |
1138 | valid_set_size=args.valid_set_size, | 1138 | valid_set_size=args.valid_set_size, |
1139 | batch_size=args.train_batch_size, | 1139 | batch_size=args.train_batch_size, |
1140 | dropout=args.tag_dropout, | 1140 | tag_dropout=args.tag_dropout, |
1141 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 1141 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
1142 | ) | 1142 | ) |
1143 | lora_datamodule.setup() | 1143 | lora_datamodule.setup() |
diff --git a/train_ti.py b/train_ti.py index 7f93960..1dbd637 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -923,7 +923,7 @@ def main(): | |||
923 | progressive_buckets=args.progressive_buckets, | 923 | progressive_buckets=args.progressive_buckets, |
924 | bucket_step_size=args.bucket_step_size, | 924 | bucket_step_size=args.bucket_step_size, |
925 | bucket_max_pixels=args.bucket_max_pixels, | 925 | bucket_max_pixels=args.bucket_max_pixels, |
926 | dropout=args.tag_dropout, | 926 | tag_dropout=args.tag_dropout, |
927 | shuffle=not args.no_tag_shuffle, | 927 | shuffle=not args.no_tag_shuffle, |
928 | template_key=data_template, | 928 | template_key=data_template, |
929 | placeholder_tokens=args.placeholder_tokens, | 929 | placeholder_tokens=args.placeholder_tokens, |