diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index ab3ed16..7745d27 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -189,6 +189,12 @@ def parse_args(): | |||
| 189 | help="Tag dropout probability.", | 189 | help="Tag dropout probability.", |
| 190 | ) | 190 | ) |
| 191 | parser.add_argument( | 191 | parser.add_argument( |
| 192 | "--prompt_dropout", | ||
| 193 | type=float, | ||
| 194 | default=0, | ||
| 195 | help="Prompt dropout probability.", | ||
| 196 | ) | ||
| 197 | parser.add_argument( | ||
| 192 | "--no_tag_shuffle", | 198 | "--no_tag_shuffle", |
| 193 | action="store_true", | 199 | action="store_true", |
| 194 | help="Shuffle tags.", | 200 | help="Shuffle tags.", |
| @@ -255,6 +261,11 @@ def parse_args(): | |||
| 255 | help="Number of epochs the text encoder will be trained.", | 261 | help="Number of epochs the text encoder will be trained.", |
| 256 | ) | 262 | ) |
| 257 | parser.add_argument( | 263 | parser.add_argument( |
| 264 | "--text_encoder_unfreeze_last_n_layers", | ||
| 265 | default=2, | ||
| 266 | help="Number of text encoder layers to train.", | ||
| 267 | ) | ||
| 268 | parser.add_argument( | ||
| 258 | "--find_lr", | 269 | "--find_lr", |
| 259 | action="store_true", | 270 | action="store_true", |
| 260 | help="Automatically find a learning rate (no training).", | 271 | help="Automatically find a learning rate (no training).", |
| @@ -908,7 +919,8 @@ def main(): | |||
| 908 | dreambooth_datamodule = create_datamodule( | 919 | dreambooth_datamodule = create_datamodule( |
| 909 | valid_set_size=args.valid_set_size, | 920 | valid_set_size=args.valid_set_size, |
| 910 | batch_size=args.train_batch_size, | 921 | batch_size=args.train_batch_size, |
| 911 | dropout=args.tag_dropout, | 922 | tag_dropout=args.tag_dropout, |
| 923 | prompt_dropout=args.prompt_dropout, | ||
| 912 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 924 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
| 913 | ) | 925 | ) |
| 914 | dreambooth_datamodule.setup() | 926 | dreambooth_datamodule.setup() |
| @@ -1051,6 +1063,7 @@ def main(): | |||
| 1051 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1063 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, |
| 1052 | sample_frequency=dreambooth_sample_frequency, | 1064 | sample_frequency=dreambooth_sample_frequency, |
| 1053 | input_pertubation=args.input_pertubation, | 1065 | input_pertubation=args.input_pertubation, |
| 1066 | text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers, | ||
| 1054 | no_val=args.valid_set_size == 0, | 1067 | no_val=args.valid_set_size == 0, |
| 1055 | avg_loss=avg_loss, | 1068 | avg_loss=avg_loss, |
| 1056 | avg_acc=avg_acc, | 1069 | avg_acc=avg_acc, |
