From 01f0b3bd5a7965776b420c97056f82601e2b7312 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Jun 2023 18:34:53 +0200 Subject: Added prompt dropout --- data/csv.py | 27 +++++++++++++++++++-------- train_dreambooth.py | 15 ++++++++++++++- train_lora.py | 2 +- train_ti.py | 2 +- 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/data/csv.py b/data/csv.py index 43bf14c..c38db6d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple): def full_prompt( self, - dropout: float = 0, + prompt_dropout: float = 0, + tag_dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None, ): + if prompt_dropout != 0 and np.random.random() <= prompt_dropout: + return "" + return keywords_to_str( - self.keywords, [self.prompt], dropout, shuffle, npgenerator + self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator ) @@ -200,7 +204,8 @@ class VlpnDataModule: bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, - dropout: float = 0, + prompt_dropout: float = 0, + tag_dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", color_jitter: bool = False, @@ -236,7 +241,8 @@ class VlpnDataModule: self.bucket_step_size = bucket_step_size self.bucket_max_pixels = bucket_max_pixels self.progressive_buckets = progressive_buckets - self.dropout = dropout + self.prompt_dropout = prompt_dropout + self.tag_dropout = tag_dropout self.shuffle = shuffle self.template_key = template_key self.interpolation = interpolation @@ -382,7 +388,8 @@ class VlpnDataModule: interpolation=self.interpolation, color_jitter=self.color_jitter, num_class_images=self.num_class_images, - dropout=self.dropout, + tag_dropout=self.tag_dropout, + prompt_dropout=self.prompt_dropout, shuffle=self.shuffle, ) @@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset): fill_batch: bool = False, num_class_images: int = 0, size: int = 768, - dropout: float = 0, + tag_dropout: float = 0, + prompt_dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", color_jitter: bool = False, @@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset): self.tokenizer = tokenizer self.num_class_images = num_class_images self.size = size - self.dropout = dropout + self.tag_dropout = tag_dropout + self.prompt_dropout = prompt_dropout self.shuffle = shuffle self.interpolation = interpolations[interpolation] self.color_jitter = color_jitter @@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset): example["nprompt_ids"] = self.get_input_ids(item.nprompt) example["instance_prompt_ids"] = self.get_input_ids( - item.full_prompt(self.dropout, True, self.npgenerator) + item.full_prompt( + self.prompt_dropout, self.tag_dropout, True, self.npgenerator + ) ) example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) example["instance_images"] = image_transforms( diff --git a/train_dreambooth.py b/train_dreambooth.py index ab3ed16..7745d27 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -188,6 +188,12 @@ def parse_args(): default=0, help="Tag dropout probability.", ) + parser.add_argument( + "--prompt_dropout", + type=float, + default=0, + help="Prompt dropout probability.", + ) parser.add_argument( "--no_tag_shuffle", action="store_true", @@ -254,6 +260,11 @@ def parse_args(): default=999999, help="Number of epochs the text encoder will be trained.", ) + parser.add_argument( + "--text_encoder_unfreeze_last_n_layers", + default=2, + help="Number of text encoder layers to train.", + ) parser.add_argument( "--find_lr", action="store_true", @@ -908,7 +919,8 @@ def main(): dreambooth_datamodule = create_datamodule( valid_set_size=args.valid_set_size, batch_size=args.train_batch_size, - dropout=args.tag_dropout, + tag_dropout=args.tag_dropout, + prompt_dropout=args.prompt_dropout, filter=partial(keyword_filter, None, args.collection, args.exclude_collections), ) dreambooth_datamodule.setup() @@ -1051,6 +1063,7 @@ def main(): checkpoint_output_dir=dreambooth_checkpoint_output_dir, sample_frequency=dreambooth_sample_frequency, input_pertubation=args.input_pertubation, + text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers, no_val=args.valid_set_size == 0, avg_loss=avg_loss, avg_acc=avg_acc, diff --git a/train_lora.py b/train_lora.py index 51dc827..1ff25ff 100644 --- a/train_lora.py +++ b/train_lora.py @@ -1137,7 +1137,7 @@ def main(): lora_datamodule = create_datamodule( valid_set_size=args.valid_set_size, batch_size=args.train_batch_size, - dropout=args.tag_dropout, + tag_dropout=args.tag_dropout, filter=partial(keyword_filter, None, args.collection, args.exclude_collections), ) lora_datamodule.setup() diff --git a/train_ti.py b/train_ti.py index 7f93960..1dbd637 100644 --- a/train_ti.py +++ b/train_ti.py @@ -923,7 +923,7 @@ def main(): progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, + tag_dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=data_template, placeholder_tokens=args.placeholder_tokens, -- cgit v1.2.3-70-g09d2