author    Volpeon <git@volpeon.ink>  2022-12-13 20:49:57 +0100
committer Volpeon <git@volpeon.ink>  2022-12-13 20:49:57 +0100
commit    b73469706091c8aaf3f028de96ab017f5a845639
tree      892208ff6c19a11b9870e0ba298d88fb0d4bd5ba
parent    Fixed sample/checkpoint frequency
Optimized Textual Inversion training by filtering the dataset for the presence of the added tokens
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  |  13 +++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
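
The heart of the optimization is the new keyword_filter=args.placeholder_token argument passed to the data module below: training items whose prompt never contains the placeholder token contribute no gradient to the new embedding, so dropping them up front shortens every epoch without changing what is learned. A minimal sketch of that filtering idea, assuming each item carries a prompt string (filter_items, the item dict layout, and the "prompt" key are illustrative assumptions, not this repo's actual data module API):

# Sketch only: keep items whose prompt mentions at least one trained token.
# Hypothetical helper; the real data module in this repo may differ.
def filter_items(items, keyword_filter):
    if keyword_filter is None:
        return items
    if isinstance(keyword_filter, str):
        keyword_filter = [keyword_filter]
    return [
        item for item in items
        if any(keyword in item["prompt"] for keyword in keyword_filter)
    ]

# With placeholder token "<skull>", only the first item survives:
items = [
    {"image": "a.jpg", "prompt": "a photo of <skull> on a table"},
    {"image": "b.jpg", "prompt": "a photo of a cat"},
]
print(len(filter_items(items, "<skull>")))  # 1
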
diff --git a/textual_inversion.py b/textual_inversion.py
index 19b8993..fd4a313 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -58,6 +58,11 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
+        "--train_data_template",
+        type=str,
+        default="template",
+    )
+    parser.add_argument(
         "--instance_identifier",
         type=str,
         default=None,
@@ -121,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0.1,
+        default=0,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -170,7 +175,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant_with_warmup",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup", "one_cycle"]'
@@ -670,8 +675,10 @@ def main():
         repeats=args.repeats,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
+        template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        keyword_filter=args.placeholder_token,
         collate_fn=collate_fn
     )
 
@@ -740,7 +747,7 @@ def main():
             num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
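
The last hunk makes the cosine-with-restarts fallback self-consistent: the estimated cycle count now subtracts the same warmup_steps value that is passed as num_warmup_steps, instead of the raw args.lr_warmup_steps. A quick worked example of the heuristic with invented numbers:

import math

# Invented values, only to show the shape of the fallback heuristic.
max_train_steps = 3000
warmup_steps = 300
num_update_steps_per_epoch = 100

num_cycles = math.ceil(math.sqrt(
    (max_train_steps - warmup_steps) / num_update_steps_per_epoch))
print(num_cycles)  # ceil(sqrt(27.0)) == 6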