From b73469706091c8aaf3f028de96ab017f5a845639 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 13 Dec 2022 20:49:57 +0100 Subject: Optimized Textual Inversion training by filtering dataset by existence of added tokens --- textual_inversion.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 19b8993..fd4a313 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -57,6 +57,11 @@ def parse_args(): default=None, help="A CSV file containing the training data." ) + parser.add_argument( + "--train_data_template", + type=str, + default="template", + ) parser.add_argument( "--instance_identifier", type=str, @@ -121,7 +126,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0.1, + default=0, help="Tag dropout probability.", ) parser.add_argument( @@ -170,7 +175,7 @@ def parse_args(): parser.add_argument( "--lr_scheduler", type=str, - default="constant_with_warmup", + default="one_cycle", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup", "one_cycle"]' @@ -670,8 +675,10 @@ def main(): repeats=args.repeats, dropout=args.tag_dropout, center_crop=args.center_crop, + template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, + keyword_filter=args.placeholder_token, collate_fn=collate_fn ) @@ -740,7 +747,7 @@ def main(): num_warmup_steps=warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_cycles=args.lr_cycles or math.ceil(math.sqrt( - ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), + ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), ) else: lr_scheduler = get_scheduler( -- cgit v1.2.3-54-g00ecf