diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 19b8993..fd4a313 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -58,6 +58,11 @@ def parse_args(): | |||
58 | help="A CSV file containing the training data." | 58 | help="A CSV file containing the training data." |
59 | ) | 59 | ) |
60 | parser.add_argument( | 60 | parser.add_argument( |
61 | "--train_data_template", | ||
62 | type=str, | ||
63 | default="template", | ||
64 | ) | ||
65 | parser.add_argument( | ||
61 | "--instance_identifier", | 66 | "--instance_identifier", |
62 | type=str, | 67 | type=str, |
63 | default=None, | 68 | default=None, |
@@ -121,7 +126,7 @@ def parse_args(): | |||
121 | parser.add_argument( | 126 | parser.add_argument( |
122 | "--tag_dropout", | 127 | "--tag_dropout", |
123 | type=float, | 128 | type=float, |
124 | default=0.1, | 129 | default=0, |
125 | help="Tag dropout probability.", | 130 | help="Tag dropout probability.", |
126 | ) | 131 | ) |
127 | parser.add_argument( | 132 | parser.add_argument( |
@@ -170,7 +175,7 @@ def parse_args(): | |||
170 | parser.add_argument( | 175 | parser.add_argument( |
171 | "--lr_scheduler", | 176 | "--lr_scheduler", |
172 | type=str, | 177 | type=str, |
173 | default="constant_with_warmup", | 178 | default="one_cycle", |
174 | help=( | 179 | help=( |
175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 180 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
176 | ' "constant", "constant_with_warmup", "one_cycle"]' | 181 | ' "constant", "constant_with_warmup", "one_cycle"]' |
@@ -670,8 +675,10 @@ def main(): | |||
670 | repeats=args.repeats, | 675 | repeats=args.repeats, |
671 | dropout=args.tag_dropout, | 676 | dropout=args.tag_dropout, |
672 | center_crop=args.center_crop, | 677 | center_crop=args.center_crop, |
678 | template_key=args.train_data_template, | ||
673 | valid_set_size=args.valid_set_size, | 679 | valid_set_size=args.valid_set_size, |
674 | num_workers=args.dataloader_num_workers, | 680 | num_workers=args.dataloader_num_workers, |
681 | keyword_filter=args.placeholder_token, | ||
675 | collate_fn=collate_fn | 682 | collate_fn=collate_fn |
676 | ) | 683 | ) |
677 | 684 | ||
@@ -740,7 +747,7 @@ def main(): | |||
740 | num_warmup_steps=warmup_steps, | 747 | num_warmup_steps=warmup_steps, |
741 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 748 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
742 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | 749 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
743 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), | 750 | ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), |
744 | ) | 751 | ) |
745 | else: | 752 | else: |
746 | lr_scheduler = get_scheduler( | 753 | lr_scheduler = get_scheduler( |