diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-13 20:49:57 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-13 20:49:57 +0100 |
| commit | b73469706091c8aaf3f028de96ab017f5a845639 (patch) | |
| tree | 892208ff6c19a11b9870e0ba298d88fb0d4bd5ba /textual_inversion.py | |
| parent | Fixed sample/checkpoint frequency (diff) | |
| download | textual-inversion-diff-b73469706091c8aaf3f028de96ab017f5a845639.tar.gz textual-inversion-diff-b73469706091c8aaf3f028de96ab017f5a845639.tar.bz2 textual-inversion-diff-b73469706091c8aaf3f028de96ab017f5a845639.zip | |
Optimized Textual Inversion training by filtering dataset by existence of added tokens
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 19b8993..fd4a313 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -58,6 +58,11 @@ def parse_args(): | |||
| 58 | help="A CSV file containing the training data." | 58 | help="A CSV file containing the training data." |
| 59 | ) | 59 | ) |
| 60 | parser.add_argument( | 60 | parser.add_argument( |
| 61 | "--train_data_template", | ||
| 62 | type=str, | ||
| 63 | default="template", | ||
| 64 | ) | ||
| 65 | parser.add_argument( | ||
| 61 | "--instance_identifier", | 66 | "--instance_identifier", |
| 62 | type=str, | 67 | type=str, |
| 63 | default=None, | 68 | default=None, |
| @@ -121,7 +126,7 @@ def parse_args(): | |||
| 121 | parser.add_argument( | 126 | parser.add_argument( |
| 122 | "--tag_dropout", | 127 | "--tag_dropout", |
| 123 | type=float, | 128 | type=float, |
| 124 | default=0.1, | 129 | default=0, |
| 125 | help="Tag dropout probability.", | 130 | help="Tag dropout probability.", |
| 126 | ) | 131 | ) |
| 127 | parser.add_argument( | 132 | parser.add_argument( |
| @@ -170,7 +175,7 @@ def parse_args(): | |||
| 170 | parser.add_argument( | 175 | parser.add_argument( |
| 171 | "--lr_scheduler", | 176 | "--lr_scheduler", |
| 172 | type=str, | 177 | type=str, |
| 173 | default="constant_with_warmup", | 178 | default="one_cycle", |
| 174 | help=( | 179 | help=( |
| 175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 180 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 176 | ' "constant", "constant_with_warmup", "one_cycle"]' | 181 | ' "constant", "constant_with_warmup", "one_cycle"]' |
| @@ -670,8 +675,10 @@ def main(): | |||
| 670 | repeats=args.repeats, | 675 | repeats=args.repeats, |
| 671 | dropout=args.tag_dropout, | 676 | dropout=args.tag_dropout, |
| 672 | center_crop=args.center_crop, | 677 | center_crop=args.center_crop, |
| 678 | template_key=args.train_data_template, | ||
| 673 | valid_set_size=args.valid_set_size, | 679 | valid_set_size=args.valid_set_size, |
| 674 | num_workers=args.dataloader_num_workers, | 680 | num_workers=args.dataloader_num_workers, |
| 681 | keyword_filter=args.placeholder_token, | ||
| 675 | collate_fn=collate_fn | 682 | collate_fn=collate_fn |
| 676 | ) | 683 | ) |
| 677 | 684 | ||
| @@ -740,7 +747,7 @@ def main(): | |||
| 740 | num_warmup_steps=warmup_steps, | 747 | num_warmup_steps=warmup_steps, |
| 741 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 748 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 742 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | 749 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
| 743 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), | 750 | ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), |
| 744 | ) | 751 | ) |
| 745 | else: | 752 | else: |
| 746 | lr_scheduler = get_scheduler( | 753 | lr_scheduler = get_scheduler( |
