summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ab3ed16..7745d27 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -189,6 +189,12 @@ def parse_args():
189 help="Tag dropout probability.", 189 help="Tag dropout probability.",
190 ) 190 )
191 parser.add_argument( 191 parser.add_argument(
192 "--prompt_dropout",
193 type=float,
194 default=0,
195 help="Prompt dropout probability.",
196 )
197 parser.add_argument(
192 "--no_tag_shuffle", 198 "--no_tag_shuffle",
193 action="store_true", 199 action="store_true",
194 help="Shuffle tags.", 200 help="Shuffle tags.",
@@ -255,6 +261,11 @@ def parse_args():
255 help="Number of epochs the text encoder will be trained.", 261 help="Number of epochs the text encoder will be trained.",
256 ) 262 )
257 parser.add_argument( 263 parser.add_argument(
264 "--text_encoder_unfreeze_last_n_layers",
265 default=2,
266 help="Number of text encoder layers to train.",
267 )
268 parser.add_argument(
258 "--find_lr", 269 "--find_lr",
259 action="store_true", 270 action="store_true",
260 help="Automatically find a learning rate (no training).", 271 help="Automatically find a learning rate (no training).",
@@ -908,7 +919,8 @@ def main():
908 dreambooth_datamodule = create_datamodule( 919 dreambooth_datamodule = create_datamodule(
909 valid_set_size=args.valid_set_size, 920 valid_set_size=args.valid_set_size,
910 batch_size=args.train_batch_size, 921 batch_size=args.train_batch_size,
911 dropout=args.tag_dropout, 922 tag_dropout=args.tag_dropout,
923 prompt_dropout=args.prompt_dropout,
912 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 924 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
913 ) 925 )
914 dreambooth_datamodule.setup() 926 dreambooth_datamodule.setup()
@@ -1051,6 +1063,7 @@ def main():
1051 checkpoint_output_dir=dreambooth_checkpoint_output_dir, 1063 checkpoint_output_dir=dreambooth_checkpoint_output_dir,
1052 sample_frequency=dreambooth_sample_frequency, 1064 sample_frequency=dreambooth_sample_frequency,
1053 input_pertubation=args.input_pertubation, 1065 input_pertubation=args.input_pertubation,
1066 text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers,
1054 no_val=args.valid_set_size == 0, 1067 no_val=args.valid_set_size == 0,
1055 avg_loss=avg_loss, 1068 avg_loss=avg_loss,
1056 avg_acc=avg_acc, 1069 avg_acc=avg_acc,