diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py index 43bf14c..c38db6d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple): | |||
| 156 | 156 | ||
| 157 | def full_prompt( | 157 | def full_prompt( |
| 158 | self, | 158 | self, |
| 159 | dropout: float = 0, | 159 | prompt_dropout: float = 0, |
| 160 | tag_dropout: float = 0, | ||
| 160 | shuffle: bool = False, | 161 | shuffle: bool = False, |
| 161 | npgenerator: Optional[np.random.Generator] = None, | 162 | npgenerator: Optional[np.random.Generator] = None, |
| 162 | ): | 163 | ): |
| 164 | if prompt_dropout != 0 and np.random.random() <= prompt_dropout: | ||
| 165 | return "" | ||
| 166 | |||
| 163 | return keywords_to_str( | 167 | return keywords_to_str( |
| 164 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | 168 | self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator |
| 165 | ) | 169 | ) |
| 166 | 170 | ||
| 167 | 171 | ||
| @@ -200,7 +204,8 @@ class VlpnDataModule: | |||
| 200 | bucket_step_size: int = 64, | 204 | bucket_step_size: int = 64, |
| 201 | bucket_max_pixels: Optional[int] = None, | 205 | bucket_max_pixels: Optional[int] = None, |
| 202 | progressive_buckets: bool = False, | 206 | progressive_buckets: bool = False, |
| 203 | dropout: float = 0, | 207 | prompt_dropout: float = 0, |
| 208 | tag_dropout: float = 0, | ||
| 204 | shuffle: bool = False, | 209 | shuffle: bool = False, |
| 205 | interpolation: str = "bicubic", | 210 | interpolation: str = "bicubic", |
| 206 | color_jitter: bool = False, | 211 | color_jitter: bool = False, |
| @@ -236,7 +241,8 @@ class VlpnDataModule: | |||
| 236 | self.bucket_step_size = bucket_step_size | 241 | self.bucket_step_size = bucket_step_size |
| 237 | self.bucket_max_pixels = bucket_max_pixels | 242 | self.bucket_max_pixels = bucket_max_pixels |
| 238 | self.progressive_buckets = progressive_buckets | 243 | self.progressive_buckets = progressive_buckets |
| 239 | self.dropout = dropout | 244 | self.prompt_dropout = prompt_dropout |
| 245 | self.tag_dropout = tag_dropout | ||
| 240 | self.shuffle = shuffle | 246 | self.shuffle = shuffle |
| 241 | self.template_key = template_key | 247 | self.template_key = template_key |
| 242 | self.interpolation = interpolation | 248 | self.interpolation = interpolation |
| @@ -382,7 +388,8 @@ class VlpnDataModule: | |||
| 382 | interpolation=self.interpolation, | 388 | interpolation=self.interpolation, |
| 383 | color_jitter=self.color_jitter, | 389 | color_jitter=self.color_jitter, |
| 384 | num_class_images=self.num_class_images, | 390 | num_class_images=self.num_class_images, |
| 385 | dropout=self.dropout, | 391 | tag_dropout=self.tag_dropout, |
| 392 | prompt_dropout=self.prompt_dropout, | ||
| 386 | shuffle=self.shuffle, | 393 | shuffle=self.shuffle, |
| 387 | ) | 394 | ) |
| 388 | 395 | ||
| @@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset): | |||
| 433 | fill_batch: bool = False, | 440 | fill_batch: bool = False, |
| 434 | num_class_images: int = 0, | 441 | num_class_images: int = 0, |
| 435 | size: int = 768, | 442 | size: int = 768, |
| 436 | dropout: float = 0, | 443 | tag_dropout: float = 0, |
| 444 | prompt_dropout: float = 0, | ||
| 437 | shuffle: bool = False, | 445 | shuffle: bool = False, |
| 438 | interpolation: str = "bicubic", | 446 | interpolation: str = "bicubic", |
| 439 | color_jitter: bool = False, | 447 | color_jitter: bool = False, |
| @@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset): | |||
| 447 | self.tokenizer = tokenizer | 455 | self.tokenizer = tokenizer |
| 448 | self.num_class_images = num_class_images | 456 | self.num_class_images = num_class_images |
| 449 | self.size = size | 457 | self.size = size |
| 450 | self.dropout = dropout | 458 | self.tag_dropout = tag_dropout |
| 459 | self.prompt_dropout = prompt_dropout | ||
| 451 | self.shuffle = shuffle | 460 | self.shuffle = shuffle |
| 452 | self.interpolation = interpolations[interpolation] | 461 | self.interpolation = interpolations[interpolation] |
| 453 | self.color_jitter = color_jitter | 462 | self.color_jitter = color_jitter |
| @@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset): | |||
| 558 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
| 559 | 568 | ||
| 560 | example["instance_prompt_ids"] = self.get_input_ids( | 569 | example["instance_prompt_ids"] = self.get_input_ids( |
| 561 | item.full_prompt(self.dropout, True, self.npgenerator) | 570 | item.full_prompt( |
| 571 | self.prompt_dropout, self.tag_dropout, True, self.npgenerator | ||
| 572 | ) | ||
| 562 | ) | 573 | ) |
| 563 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 574 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
| 564 | example["instance_images"] = image_transforms( | 575 | example["instance_images"] = image_transforms( |
