diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index ab3ed16..7745d27 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -189,6 +189,12 @@ def parse_args(): | |||
189 | help="Tag dropout probability.", | 189 | help="Tag dropout probability.", |
190 | ) | 190 | ) |
191 | parser.add_argument( | 191 | parser.add_argument( |
192 | "--prompt_dropout", | ||
193 | type=float, | ||
194 | default=0, | ||
195 | help="Prompt dropout probability.", | ||
196 | ) | ||
197 | parser.add_argument( | ||
192 | "--no_tag_shuffle", | 198 | "--no_tag_shuffle", |
193 | action="store_true", | 199 | action="store_true", |
194 | help="Shuffle tags.", | 200 | help="Shuffle tags.", |
@@ -255,6 +261,11 @@ def parse_args(): | |||
255 | help="Number of epochs the text encoder will be trained.", | 261 | help="Number of epochs the text encoder will be trained.", |
256 | ) | 262 | ) |
257 | parser.add_argument( | 263 | parser.add_argument( |
264 | "--text_encoder_unfreeze_last_n_layers", | ||
265 | default=2, | ||
266 | help="Number of text encoder layers to train.", | ||
267 | ) | ||
268 | parser.add_argument( | ||
258 | "--find_lr", | 269 | "--find_lr", |
259 | action="store_true", | 270 | action="store_true", |
260 | help="Automatically find a learning rate (no training).", | 271 | help="Automatically find a learning rate (no training).", |
@@ -908,7 +919,8 @@ def main(): | |||
908 | dreambooth_datamodule = create_datamodule( | 919 | dreambooth_datamodule = create_datamodule( |
909 | valid_set_size=args.valid_set_size, | 920 | valid_set_size=args.valid_set_size, |
910 | batch_size=args.train_batch_size, | 921 | batch_size=args.train_batch_size, |
911 | dropout=args.tag_dropout, | 922 | tag_dropout=args.tag_dropout, |
923 | prompt_dropout=args.prompt_dropout, | ||
912 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 924 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
913 | ) | 925 | ) |
914 | dreambooth_datamodule.setup() | 926 | dreambooth_datamodule.setup() |
@@ -1051,6 +1063,7 @@ def main(): | |||
1051 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1063 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, |
1052 | sample_frequency=dreambooth_sample_frequency, | 1064 | sample_frequency=dreambooth_sample_frequency, |
1053 | input_pertubation=args.input_pertubation, | 1065 | input_pertubation=args.input_pertubation, |
1066 | text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers, | ||
1054 | no_val=args.valid_set_size == 0, | 1067 | no_val=args.valid_set_size == 0, |
1055 | avg_loss=avg_loss, | 1068 | avg_loss=avg_loss, |
1056 | avg_acc=avg_acc, | 1069 | avg_acc=avg_acc, |