summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-22 18:34:53 +0200
committerVolpeon <git@volpeon.ink>2023-06-22 18:34:53 +0200
commit01f0b3bd5a7965776b420c97056f82601e2b7312 (patch)
tree236bde99b6116abdc7ca75fb57828ce21b74ba32 /data
parentUpdate (diff)
downloadtextual-inversion-diff-01f0b3bd5a7965776b420c97056f82601e2b7312.tar.gz
textual-inversion-diff-01f0b3bd5a7965776b420c97056f82601e2b7312.tar.bz2
textual-inversion-diff-01f0b3bd5a7965776b420c97056f82601e2b7312.zip
Added prompt dropout
Diffstat (limited to 'data')
-rw-r--r--data/csv.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/data/csv.py b/data/csv.py
index 43bf14c..c38db6d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple):
156 156
157 def full_prompt( 157 def full_prompt(
158 self, 158 self,
159 dropout: float = 0, 159 prompt_dropout: float = 0,
160 tag_dropout: float = 0,
160 shuffle: bool = False, 161 shuffle: bool = False,
161 npgenerator: Optional[np.random.Generator] = None, 162 npgenerator: Optional[np.random.Generator] = None,
162 ): 163 ):
164 if prompt_dropout != 0 and np.random.random() <= prompt_dropout:
165 return ""
166
163 return keywords_to_str( 167 return keywords_to_str(
164 self.keywords, [self.prompt], dropout, shuffle, npgenerator 168 self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator
165 ) 169 )
166 170
167 171
@@ -200,7 +204,8 @@ class VlpnDataModule:
200 bucket_step_size: int = 64, 204 bucket_step_size: int = 64,
201 bucket_max_pixels: Optional[int] = None, 205 bucket_max_pixels: Optional[int] = None,
202 progressive_buckets: bool = False, 206 progressive_buckets: bool = False,
203 dropout: float = 0, 207 prompt_dropout: float = 0,
208 tag_dropout: float = 0,
204 shuffle: bool = False, 209 shuffle: bool = False,
205 interpolation: str = "bicubic", 210 interpolation: str = "bicubic",
206 color_jitter: bool = False, 211 color_jitter: bool = False,
@@ -236,7 +241,8 @@ class VlpnDataModule:
236 self.bucket_step_size = bucket_step_size 241 self.bucket_step_size = bucket_step_size
237 self.bucket_max_pixels = bucket_max_pixels 242 self.bucket_max_pixels = bucket_max_pixels
238 self.progressive_buckets = progressive_buckets 243 self.progressive_buckets = progressive_buckets
239 self.dropout = dropout 244 self.prompt_dropout = prompt_dropout
245 self.tag_dropout = tag_dropout
240 self.shuffle = shuffle 246 self.shuffle = shuffle
241 self.template_key = template_key 247 self.template_key = template_key
242 self.interpolation = interpolation 248 self.interpolation = interpolation
@@ -382,7 +388,8 @@ class VlpnDataModule:
382 interpolation=self.interpolation, 388 interpolation=self.interpolation,
383 color_jitter=self.color_jitter, 389 color_jitter=self.color_jitter,
384 num_class_images=self.num_class_images, 390 num_class_images=self.num_class_images,
385 dropout=self.dropout, 391 tag_dropout=self.tag_dropout,
392 prompt_dropout=self.prompt_dropout,
386 shuffle=self.shuffle, 393 shuffle=self.shuffle,
387 ) 394 )
388 395
@@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset):
433 fill_batch: bool = False, 440 fill_batch: bool = False,
434 num_class_images: int = 0, 441 num_class_images: int = 0,
435 size: int = 768, 442 size: int = 768,
436 dropout: float = 0, 443 tag_dropout: float = 0,
444 prompt_dropout: float = 0,
437 shuffle: bool = False, 445 shuffle: bool = False,
438 interpolation: str = "bicubic", 446 interpolation: str = "bicubic",
439 color_jitter: bool = False, 447 color_jitter: bool = False,
@@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset):
447 self.tokenizer = tokenizer 455 self.tokenizer = tokenizer
448 self.num_class_images = num_class_images 456 self.num_class_images = num_class_images
449 self.size = size 457 self.size = size
450 self.dropout = dropout 458 self.tag_dropout = tag_dropout
459 self.prompt_dropout = prompt_dropout
451 self.shuffle = shuffle 460 self.shuffle = shuffle
452 self.interpolation = interpolations[interpolation] 461 self.interpolation = interpolations[interpolation]
453 self.color_jitter = color_jitter 462 self.color_jitter = color_jitter
@@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset):
558 example["nprompt_ids"] = self.get_input_ids(item.nprompt) 567 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
559 568
560 example["instance_prompt_ids"] = self.get_input_ids( 569 example["instance_prompt_ids"] = self.get_input_ids(
561 item.full_prompt(self.dropout, True, self.npgenerator) 570 item.full_prompt(
571 self.prompt_dropout, self.tag_dropout, True, self.npgenerator
572 )
562 ) 573 )
563 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) 574 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
564 example["instance_images"] = image_transforms( 575 example["instance_images"] = image_transforms(