From 01f0b3bd5a7965776b420c97056f82601e2b7312 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Jun 2023 18:34:53 +0200 Subject: Added prompt dropout --- data/csv.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index 43bf14c..c38db6d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple): def full_prompt( self, - dropout: float = 0, + prompt_dropout: float = 0, + tag_dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None, ): + if prompt_dropout != 0 and np.random.random() <= prompt_dropout: + return "" + return keywords_to_str( - self.keywords, [self.prompt], dropout, shuffle, npgenerator + self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator ) @@ -200,7 +204,8 @@ class VlpnDataModule: bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, - dropout: float = 0, + prompt_dropout: float = 0, + tag_dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", color_jitter: bool = False, @@ -236,7 +241,8 @@ class VlpnDataModule: self.bucket_step_size = bucket_step_size self.bucket_max_pixels = bucket_max_pixels self.progressive_buckets = progressive_buckets - self.dropout = dropout + self.prompt_dropout = prompt_dropout + self.tag_dropout = tag_dropout self.shuffle = shuffle self.template_key = template_key self.interpolation = interpolation @@ -382,7 +388,8 @@ class VlpnDataModule: interpolation=self.interpolation, color_jitter=self.color_jitter, num_class_images=self.num_class_images, - dropout=self.dropout, + tag_dropout=self.tag_dropout, + prompt_dropout=self.prompt_dropout, shuffle=self.shuffle, ) @@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset): fill_batch: bool = False, num_class_images: int = 0, size: int = 768, - dropout: float = 0, + tag_dropout: float = 0, + prompt_dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", color_jitter: bool = False, @@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset): self.tokenizer = tokenizer self.num_class_images = num_class_images self.size = size - self.dropout = dropout + self.tag_dropout = tag_dropout + self.prompt_dropout = prompt_dropout self.shuffle = shuffle self.interpolation = interpolations[interpolation] self.color_jitter = color_jitter @@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset): example["nprompt_ids"] = self.get_input_ids(item.nprompt) example["instance_prompt_ids"] = self.get_input_ids( - item.full_prompt(self.dropout, True, self.npgenerator) + item.full_prompt( + self.prompt_dropout, self.tag_dropout, True, self.npgenerator + ) ) example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) example["instance_images"] = image_transforms( -- cgit v1.2.3-54-g00ecf