diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 27 |
1 file changed, 19 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py index 43bf14c..c38db6d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple): | |||
156 | 156 | ||
157 | def full_prompt( | 157 | def full_prompt( |
158 | self, | 158 | self, |
159 | dropout: float = 0, | 159 | prompt_dropout: float = 0, |
160 | tag_dropout: float = 0, | ||
160 | shuffle: bool = False, | 161 | shuffle: bool = False, |
161 | npgenerator: Optional[np.random.Generator] = None, | 162 | npgenerator: Optional[np.random.Generator] = None, |
162 | ): | 163 | ): |
164 | if prompt_dropout != 0 and np.random.random() <= prompt_dropout: | ||
165 | return "" | ||
166 | |||
163 | return keywords_to_str( | 167 | return keywords_to_str( |
164 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | 168 | self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator |
165 | ) | 169 | ) |
166 | 170 | ||
167 | 171 | ||
@@ -200,7 +204,8 @@ class VlpnDataModule: | |||
200 | bucket_step_size: int = 64, | 204 | bucket_step_size: int = 64, |
201 | bucket_max_pixels: Optional[int] = None, | 205 | bucket_max_pixels: Optional[int] = None, |
202 | progressive_buckets: bool = False, | 206 | progressive_buckets: bool = False, |
203 | dropout: float = 0, | 207 | prompt_dropout: float = 0, |
208 | tag_dropout: float = 0, | ||
204 | shuffle: bool = False, | 209 | shuffle: bool = False, |
205 | interpolation: str = "bicubic", | 210 | interpolation: str = "bicubic", |
206 | color_jitter: bool = False, | 211 | color_jitter: bool = False, |
@@ -236,7 +241,8 @@ class VlpnDataModule: | |||
236 | self.bucket_step_size = bucket_step_size | 241 | self.bucket_step_size = bucket_step_size |
237 | self.bucket_max_pixels = bucket_max_pixels | 242 | self.bucket_max_pixels = bucket_max_pixels |
238 | self.progressive_buckets = progressive_buckets | 243 | self.progressive_buckets = progressive_buckets |
239 | self.dropout = dropout | 244 | self.prompt_dropout = prompt_dropout |
245 | self.tag_dropout = tag_dropout | ||
240 | self.shuffle = shuffle | 246 | self.shuffle = shuffle |
241 | self.template_key = template_key | 247 | self.template_key = template_key |
242 | self.interpolation = interpolation | 248 | self.interpolation = interpolation |
@@ -382,7 +388,8 @@ class VlpnDataModule: | |||
382 | interpolation=self.interpolation, | 388 | interpolation=self.interpolation, |
383 | color_jitter=self.color_jitter, | 389 | color_jitter=self.color_jitter, |
384 | num_class_images=self.num_class_images, | 390 | num_class_images=self.num_class_images, |
385 | dropout=self.dropout, | 391 | tag_dropout=self.tag_dropout, |
392 | prompt_dropout=self.prompt_dropout, | ||
386 | shuffle=self.shuffle, | 393 | shuffle=self.shuffle, |
387 | ) | 394 | ) |
388 | 395 | ||
@@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset): | |||
433 | fill_batch: bool = False, | 440 | fill_batch: bool = False, |
434 | num_class_images: int = 0, | 441 | num_class_images: int = 0, |
435 | size: int = 768, | 442 | size: int = 768, |
436 | dropout: float = 0, | 443 | tag_dropout: float = 0, |
444 | prompt_dropout: float = 0, | ||
437 | shuffle: bool = False, | 445 | shuffle: bool = False, |
438 | interpolation: str = "bicubic", | 446 | interpolation: str = "bicubic", |
439 | color_jitter: bool = False, | 447 | color_jitter: bool = False, |
@@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset): | |||
447 | self.tokenizer = tokenizer | 455 | self.tokenizer = tokenizer |
448 | self.num_class_images = num_class_images | 456 | self.num_class_images = num_class_images |
449 | self.size = size | 457 | self.size = size |
450 | self.dropout = dropout | 458 | self.tag_dropout = tag_dropout |
459 | self.prompt_dropout = prompt_dropout | ||
451 | self.shuffle = shuffle | 460 | self.shuffle = shuffle |
452 | self.interpolation = interpolations[interpolation] | 461 | self.interpolation = interpolations[interpolation] |
453 | self.color_jitter = color_jitter | 462 | self.color_jitter = color_jitter |
@@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset): | |||
558 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
559 | 568 | ||
560 | example["instance_prompt_ids"] = self.get_input_ids( | 569 | example["instance_prompt_ids"] = self.get_input_ids( |
561 | item.full_prompt(self.dropout, True, self.npgenerator) | 570 | item.full_prompt( |
571 | self.prompt_dropout, self.tag_dropout, True, self.npgenerator | ||
572 | ) | ||
562 | ) | 573 | ) |
563 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 574 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
564 | example["instance_images"] = image_transforms( | 575 | example["instance_images"] = image_transforms( |