diff options
| -rw-r--r-- | data/csv.py | 27 | ||||
| -rw-r--r-- | train_dreambooth.py | 15 | ||||
| -rw-r--r-- | train_lora.py | 2 | ||||
| -rw-r--r-- | train_ti.py | 2 | 
4 files changed, 35 insertions, 11 deletions
| diff --git a/data/csv.py b/data/csv.py index 43bf14c..c38db6d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple): | |||
| 156 | 156 | ||
| 157 | def full_prompt( | 157 | def full_prompt( | 
| 158 | self, | 158 | self, | 
| 159 | dropout: float = 0, | 159 | prompt_dropout: float = 0, | 
| 160 | tag_dropout: float = 0, | ||
| 160 | shuffle: bool = False, | 161 | shuffle: bool = False, | 
| 161 | npgenerator: Optional[np.random.Generator] = None, | 162 | npgenerator: Optional[np.random.Generator] = None, | 
| 162 | ): | 163 | ): | 
| 164 | if prompt_dropout != 0 and np.random.random() <= prompt_dropout: | ||
| 165 | return "" | ||
| 166 | |||
| 163 | return keywords_to_str( | 167 | return keywords_to_str( | 
| 164 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | 168 | self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator | 
| 165 | ) | 169 | ) | 
| 166 | 170 | ||
| 167 | 171 | ||
| @@ -200,7 +204,8 @@ class VlpnDataModule: | |||
| 200 | bucket_step_size: int = 64, | 204 | bucket_step_size: int = 64, | 
| 201 | bucket_max_pixels: Optional[int] = None, | 205 | bucket_max_pixels: Optional[int] = None, | 
| 202 | progressive_buckets: bool = False, | 206 | progressive_buckets: bool = False, | 
| 203 | dropout: float = 0, | 207 | prompt_dropout: float = 0, | 
| 208 | tag_dropout: float = 0, | ||
| 204 | shuffle: bool = False, | 209 | shuffle: bool = False, | 
| 205 | interpolation: str = "bicubic", | 210 | interpolation: str = "bicubic", | 
| 206 | color_jitter: bool = False, | 211 | color_jitter: bool = False, | 
| @@ -236,7 +241,8 @@ class VlpnDataModule: | |||
| 236 | self.bucket_step_size = bucket_step_size | 241 | self.bucket_step_size = bucket_step_size | 
| 237 | self.bucket_max_pixels = bucket_max_pixels | 242 | self.bucket_max_pixels = bucket_max_pixels | 
| 238 | self.progressive_buckets = progressive_buckets | 243 | self.progressive_buckets = progressive_buckets | 
| 239 | self.dropout = dropout | 244 | self.prompt_dropout = prompt_dropout | 
| 245 | self.tag_dropout = tag_dropout | ||
| 240 | self.shuffle = shuffle | 246 | self.shuffle = shuffle | 
| 241 | self.template_key = template_key | 247 | self.template_key = template_key | 
| 242 | self.interpolation = interpolation | 248 | self.interpolation = interpolation | 
| @@ -382,7 +388,8 @@ class VlpnDataModule: | |||
| 382 | interpolation=self.interpolation, | 388 | interpolation=self.interpolation, | 
| 383 | color_jitter=self.color_jitter, | 389 | color_jitter=self.color_jitter, | 
| 384 | num_class_images=self.num_class_images, | 390 | num_class_images=self.num_class_images, | 
| 385 | dropout=self.dropout, | 391 | tag_dropout=self.tag_dropout, | 
| 392 | prompt_dropout=self.prompt_dropout, | ||
| 386 | shuffle=self.shuffle, | 393 | shuffle=self.shuffle, | 
| 387 | ) | 394 | ) | 
| 388 | 395 | ||
| @@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset): | |||
| 433 | fill_batch: bool = False, | 440 | fill_batch: bool = False, | 
| 434 | num_class_images: int = 0, | 441 | num_class_images: int = 0, | 
| 435 | size: int = 768, | 442 | size: int = 768, | 
| 436 | dropout: float = 0, | 443 | tag_dropout: float = 0, | 
| 444 | prompt_dropout: float = 0, | ||
| 437 | shuffle: bool = False, | 445 | shuffle: bool = False, | 
| 438 | interpolation: str = "bicubic", | 446 | interpolation: str = "bicubic", | 
| 439 | color_jitter: bool = False, | 447 | color_jitter: bool = False, | 
| @@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset): | |||
| 447 | self.tokenizer = tokenizer | 455 | self.tokenizer = tokenizer | 
| 448 | self.num_class_images = num_class_images | 456 | self.num_class_images = num_class_images | 
| 449 | self.size = size | 457 | self.size = size | 
| 450 | self.dropout = dropout | 458 | self.tag_dropout = tag_dropout | 
| 459 | self.prompt_dropout = prompt_dropout | ||
| 451 | self.shuffle = shuffle | 460 | self.shuffle = shuffle | 
| 452 | self.interpolation = interpolations[interpolation] | 461 | self.interpolation = interpolations[interpolation] | 
| 453 | self.color_jitter = color_jitter | 462 | self.color_jitter = color_jitter | 
| @@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset): | |||
| 558 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 
| 559 | 568 | ||
| 560 | example["instance_prompt_ids"] = self.get_input_ids( | 569 | example["instance_prompt_ids"] = self.get_input_ids( | 
| 561 | item.full_prompt(self.dropout, True, self.npgenerator) | 570 | item.full_prompt( | 
| 571 | self.prompt_dropout, self.tag_dropout, True, self.npgenerator | ||
| 572 | ) | ||
| 562 | ) | 573 | ) | 
| 563 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 574 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 
| 564 | example["instance_images"] = image_transforms( | 575 | example["instance_images"] = image_transforms( | 
| diff --git a/train_dreambooth.py b/train_dreambooth.py index ab3ed16..7745d27 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -189,6 +189,12 @@ def parse_args(): | |||
| 189 | help="Tag dropout probability.", | 189 | help="Tag dropout probability.", | 
| 190 | ) | 190 | ) | 
| 191 | parser.add_argument( | 191 | parser.add_argument( | 
| 192 | "--prompt_dropout", | ||
| 193 | type=float, | ||
| 194 | default=0, | ||
| 195 | help="Prompt dropout probability.", | ||
| 196 | ) | ||
| 197 | parser.add_argument( | ||
| 192 | "--no_tag_shuffle", | 198 | "--no_tag_shuffle", | 
| 193 | action="store_true", | 199 | action="store_true", | 
| 194 | help="Shuffle tags.", | 200 | help="Shuffle tags.", | 
| @@ -255,6 +261,11 @@ def parse_args(): | |||
| 255 | help="Number of epochs the text encoder will be trained.", | 261 | help="Number of epochs the text encoder will be trained.", | 
| 256 | ) | 262 | ) | 
| 257 | parser.add_argument( | 263 | parser.add_argument( | 
| 264 | "--text_encoder_unfreeze_last_n_layers", | ||
| 265 | default=2, | ||
| 266 | help="Number of text encoder layers to train.", | ||
| 267 | ) | ||
| 268 | parser.add_argument( | ||
| 258 | "--find_lr", | 269 | "--find_lr", | 
| 259 | action="store_true", | 270 | action="store_true", | 
| 260 | help="Automatically find a learning rate (no training).", | 271 | help="Automatically find a learning rate (no training).", | 
| @@ -908,7 +919,8 @@ def main(): | |||
| 908 | dreambooth_datamodule = create_datamodule( | 919 | dreambooth_datamodule = create_datamodule( | 
| 909 | valid_set_size=args.valid_set_size, | 920 | valid_set_size=args.valid_set_size, | 
| 910 | batch_size=args.train_batch_size, | 921 | batch_size=args.train_batch_size, | 
| 911 | dropout=args.tag_dropout, | 922 | tag_dropout=args.tag_dropout, | 
| 923 | prompt_dropout=args.prompt_dropout, | ||
| 912 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 924 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 
| 913 | ) | 925 | ) | 
| 914 | dreambooth_datamodule.setup() | 926 | dreambooth_datamodule.setup() | 
| @@ -1051,6 +1063,7 @@ def main(): | |||
| 1051 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1063 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 
| 1052 | sample_frequency=dreambooth_sample_frequency, | 1064 | sample_frequency=dreambooth_sample_frequency, | 
| 1053 | input_pertubation=args.input_pertubation, | 1065 | input_pertubation=args.input_pertubation, | 
| 1066 | text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers, | ||
| 1054 | no_val=args.valid_set_size == 0, | 1067 | no_val=args.valid_set_size == 0, | 
| 1055 | avg_loss=avg_loss, | 1068 | avg_loss=avg_loss, | 
| 1056 | avg_acc=avg_acc, | 1069 | avg_acc=avg_acc, | 
| diff --git a/train_lora.py b/train_lora.py index 51dc827..1ff25ff 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -1137,7 +1137,7 @@ def main(): | |||
| 1137 | lora_datamodule = create_datamodule( | 1137 | lora_datamodule = create_datamodule( | 
| 1138 | valid_set_size=args.valid_set_size, | 1138 | valid_set_size=args.valid_set_size, | 
| 1139 | batch_size=args.train_batch_size, | 1139 | batch_size=args.train_batch_size, | 
| 1140 | dropout=args.tag_dropout, | 1140 | tag_dropout=args.tag_dropout, | 
| 1141 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 1141 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 
| 1142 | ) | 1142 | ) | 
| 1143 | lora_datamodule.setup() | 1143 | lora_datamodule.setup() | 
| diff --git a/train_ti.py b/train_ti.py index 7f93960..1dbd637 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -923,7 +923,7 @@ def main(): | |||
| 923 | progressive_buckets=args.progressive_buckets, | 923 | progressive_buckets=args.progressive_buckets, | 
| 924 | bucket_step_size=args.bucket_step_size, | 924 | bucket_step_size=args.bucket_step_size, | 
| 925 | bucket_max_pixels=args.bucket_max_pixels, | 925 | bucket_max_pixels=args.bucket_max_pixels, | 
| 926 | dropout=args.tag_dropout, | 926 | tag_dropout=args.tag_dropout, | 
| 927 | shuffle=not args.no_tag_shuffle, | 927 | shuffle=not args.no_tag_shuffle, | 
| 928 | template_key=data_template, | 928 | template_key=data_template, | 
| 929 | placeholder_tokens=args.placeholder_tokens, | 929 | placeholder_tokens=args.placeholder_tokens, | 
