summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py27
-rw-r--r--train_dreambooth.py15
-rw-r--r--train_lora.py2
-rw-r--r--train_ti.py2
4 files changed, 35 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py
index 43bf14c..c38db6d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -156,12 +156,16 @@ class VlpnDataItem(NamedTuple):
156 156
157 def full_prompt( 157 def full_prompt(
158 self, 158 self,
159 dropout: float = 0, 159 prompt_dropout: float = 0,
160 tag_dropout: float = 0,
160 shuffle: bool = False, 161 shuffle: bool = False,
161 npgenerator: Optional[np.random.Generator] = None, 162 npgenerator: Optional[np.random.Generator] = None,
162 ): 163 ):
164 if prompt_dropout != 0 and np.random.random() <= prompt_dropout:
165 return ""
166
163 return keywords_to_str( 167 return keywords_to_str(
164 self.keywords, [self.prompt], dropout, shuffle, npgenerator 168 self.keywords, [self.prompt], tag_dropout, shuffle, npgenerator
165 ) 169 )
166 170
167 171
@@ -200,7 +204,8 @@ class VlpnDataModule:
200 bucket_step_size: int = 64, 204 bucket_step_size: int = 64,
201 bucket_max_pixels: Optional[int] = None, 205 bucket_max_pixels: Optional[int] = None,
202 progressive_buckets: bool = False, 206 progressive_buckets: bool = False,
203 dropout: float = 0, 207 prompt_dropout: float = 0,
208 tag_dropout: float = 0,
204 shuffle: bool = False, 209 shuffle: bool = False,
205 interpolation: str = "bicubic", 210 interpolation: str = "bicubic",
206 color_jitter: bool = False, 211 color_jitter: bool = False,
@@ -236,7 +241,8 @@ class VlpnDataModule:
236 self.bucket_step_size = bucket_step_size 241 self.bucket_step_size = bucket_step_size
237 self.bucket_max_pixels = bucket_max_pixels 242 self.bucket_max_pixels = bucket_max_pixels
238 self.progressive_buckets = progressive_buckets 243 self.progressive_buckets = progressive_buckets
239 self.dropout = dropout 244 self.prompt_dropout = prompt_dropout
245 self.tag_dropout = tag_dropout
240 self.shuffle = shuffle 246 self.shuffle = shuffle
241 self.template_key = template_key 247 self.template_key = template_key
242 self.interpolation = interpolation 248 self.interpolation = interpolation
@@ -382,7 +388,8 @@ class VlpnDataModule:
382 interpolation=self.interpolation, 388 interpolation=self.interpolation,
383 color_jitter=self.color_jitter, 389 color_jitter=self.color_jitter,
384 num_class_images=self.num_class_images, 390 num_class_images=self.num_class_images,
385 dropout=self.dropout, 391 tag_dropout=self.tag_dropout,
392 prompt_dropout=self.prompt_dropout,
386 shuffle=self.shuffle, 393 shuffle=self.shuffle,
387 ) 394 )
388 395
@@ -433,7 +440,8 @@ class VlpnDataset(IterableDataset):
433 fill_batch: bool = False, 440 fill_batch: bool = False,
434 num_class_images: int = 0, 441 num_class_images: int = 0,
435 size: int = 768, 442 size: int = 768,
436 dropout: float = 0, 443 tag_dropout: float = 0,
444 prompt_dropout: float = 0,
437 shuffle: bool = False, 445 shuffle: bool = False,
438 interpolation: str = "bicubic", 446 interpolation: str = "bicubic",
439 color_jitter: bool = False, 447 color_jitter: bool = False,
@@ -447,7 +455,8 @@ class VlpnDataset(IterableDataset):
447 self.tokenizer = tokenizer 455 self.tokenizer = tokenizer
448 self.num_class_images = num_class_images 456 self.num_class_images = num_class_images
449 self.size = size 457 self.size = size
450 self.dropout = dropout 458 self.tag_dropout = tag_dropout
459 self.prompt_dropout = prompt_dropout
451 self.shuffle = shuffle 460 self.shuffle = shuffle
452 self.interpolation = interpolations[interpolation] 461 self.interpolation = interpolations[interpolation]
453 self.color_jitter = color_jitter 462 self.color_jitter = color_jitter
@@ -558,7 +567,9 @@ class VlpnDataset(IterableDataset):
558 example["nprompt_ids"] = self.get_input_ids(item.nprompt) 567 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
559 568
560 example["instance_prompt_ids"] = self.get_input_ids( 569 example["instance_prompt_ids"] = self.get_input_ids(
561 item.full_prompt(self.dropout, True, self.npgenerator) 570 item.full_prompt(
571 self.prompt_dropout, self.tag_dropout, True, self.npgenerator
572 )
562 ) 573 )
563 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) 574 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
564 example["instance_images"] = image_transforms( 575 example["instance_images"] = image_transforms(
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ab3ed16..7745d27 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -189,6 +189,12 @@ def parse_args():
189 help="Tag dropout probability.", 189 help="Tag dropout probability.",
190 ) 190 )
191 parser.add_argument( 191 parser.add_argument(
192 "--prompt_dropout",
193 type=float,
194 default=0,
195 help="Prompt dropout probability.",
196 )
197 parser.add_argument(
192 "--no_tag_shuffle", 198 "--no_tag_shuffle",
193 action="store_true", 199 action="store_true",
194 help="Shuffle tags.", 200 help="Shuffle tags.",
@@ -255,6 +261,11 @@ def parse_args():
255 help="Number of epochs the text encoder will be trained.", 261 help="Number of epochs the text encoder will be trained.",
256 ) 262 )
257 parser.add_argument( 263 parser.add_argument(
264 "--text_encoder_unfreeze_last_n_layers",
265 default=2,
266 help="Number of text encoder layers to train.",
267 )
268 parser.add_argument(
258 "--find_lr", 269 "--find_lr",
259 action="store_true", 270 action="store_true",
260 help="Automatically find a learning rate (no training).", 271 help="Automatically find a learning rate (no training).",
@@ -908,7 +919,8 @@ def main():
908 dreambooth_datamodule = create_datamodule( 919 dreambooth_datamodule = create_datamodule(
909 valid_set_size=args.valid_set_size, 920 valid_set_size=args.valid_set_size,
910 batch_size=args.train_batch_size, 921 batch_size=args.train_batch_size,
911 dropout=args.tag_dropout, 922 tag_dropout=args.tag_dropout,
923 prompt_dropout=args.prompt_dropout,
912 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 924 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
913 ) 925 )
914 dreambooth_datamodule.setup() 926 dreambooth_datamodule.setup()
@@ -1051,6 +1063,7 @@ def main():
1051 checkpoint_output_dir=dreambooth_checkpoint_output_dir, 1063 checkpoint_output_dir=dreambooth_checkpoint_output_dir,
1052 sample_frequency=dreambooth_sample_frequency, 1064 sample_frequency=dreambooth_sample_frequency,
1053 input_pertubation=args.input_pertubation, 1065 input_pertubation=args.input_pertubation,
1066 text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers,
1054 no_val=args.valid_set_size == 0, 1067 no_val=args.valid_set_size == 0,
1055 avg_loss=avg_loss, 1068 avg_loss=avg_loss,
1056 avg_acc=avg_acc, 1069 avg_acc=avg_acc,
diff --git a/train_lora.py b/train_lora.py
index 51dc827..1ff25ff 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1137,7 +1137,7 @@ def main():
1137 lora_datamodule = create_datamodule( 1137 lora_datamodule = create_datamodule(
1138 valid_set_size=args.valid_set_size, 1138 valid_set_size=args.valid_set_size,
1139 batch_size=args.train_batch_size, 1139 batch_size=args.train_batch_size,
1140 dropout=args.tag_dropout, 1140 tag_dropout=args.tag_dropout,
1141 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 1141 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
1142 ) 1142 )
1143 lora_datamodule.setup() 1143 lora_datamodule.setup()
diff --git a/train_ti.py b/train_ti.py
index 7f93960..1dbd637 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -923,7 +923,7 @@ def main():
923 progressive_buckets=args.progressive_buckets, 923 progressive_buckets=args.progressive_buckets,
924 bucket_step_size=args.bucket_step_size, 924 bucket_step_size=args.bucket_step_size,
925 bucket_max_pixels=args.bucket_max_pixels, 925 bucket_max_pixels=args.bucket_max_pixels,
926 dropout=args.tag_dropout, 926 tag_dropout=args.tag_dropout,
927 shuffle=not args.no_tag_shuffle, 927 shuffle=not args.no_tag_shuffle,
928 template_key=data_template, 928 template_key=data_template,
929 placeholder_tokens=args.placeholder_tokens, 929 placeholder_tokens=args.placeholder_tokens,