summary | refs | log | tree | commit | diff | stats
path: root/train_ti.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-12-30 13:48:26 +0100
committer: Volpeon <git@volpeon.ink> 2022-12-30 13:48:26 +0100
commit: dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch)
tree: da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_ti.py
parent: Update (diff)
download: textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip
Training script improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 43
1 files changed, 38 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py
index b1f6a49..6aa4007 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,6 +93,18 @@ def parse_args():
93 help="The directory where class images will be saved.", 93 help="The directory where class images will be saved.",
94 ) 94 )
95 parser.add_argument( 95 parser.add_argument(
96 "--exclude_keywords",
97 type=str,
98 nargs='*',
99 help="Skip dataset items containing a listed keyword.",
100 )
101 parser.add_argument(
102 "--exclude_modes",
103 type=str,
104 nargs='*',
105 help="Exclude all items with a listed mode.",
106 )
107 parser.add_argument(
96 "--repeats", 108 "--repeats",
97 type=int, 109 type=int,
98 default=1, 110 default=1,
@@ -120,7 +132,8 @@ def parse_args():
120 "--seed", 132 "--seed",
121 type=int, 133 type=int,
122 default=None, 134 default=None,
123 help="A seed for reproducible training.") 135 help="A seed for reproducible training."
136 )
124 parser.add_argument( 137 parser.add_argument(
125 "--resolution", 138 "--resolution",
126 type=int, 139 type=int,
@@ -356,6 +369,12 @@ def parse_args():
356 if len(args.placeholder_token) != len(args.initializer_token): 369 if len(args.placeholder_token) != len(args.initializer_token):
357 raise ValueError("You must specify --placeholder_token") 370 raise ValueError("You must specify --placeholder_token")
358 371
372 if isinstance(args.exclude_keywords, str):
373 args.exclude_keywords = [args.exclude_keywords]
374
375 if isinstance(args.exclude_modes, str):
376 args.exclude_modes = [args.exclude_modes]
377
359 if args.output_dir is None: 378 if args.output_dir is None:
360 raise ValueError("You must specify --output_dir") 379 raise ValueError("You must specify --output_dir")
361 380
@@ -576,11 +595,22 @@ def main():
576 weight_dtype = torch.bfloat16 595 weight_dtype = torch.bfloat16
577 596
578 def keyword_filter(item: CSVDataItem): 597 def keyword_filter(item: CSVDataItem):
579 return any( 598 cond1 = any(
580 keyword in part 599 keyword in part
581 for keyword in args.placeholder_token 600 for keyword in args.placeholder_token
582 for part in item.prompt 601 for part in item.prompt
583 ) 602 )
603 cond2 = args.exclude_keywords is None or not any(
604 keyword in part
605 for keyword in args.exclude_keywords
606 for part in item.prompt
607 )
608 cond3 = args.mode is None or args.mode in item.mode
609 cond4 = args.exclude_modes is None or not any(
610 mode in item.mode
611 for mode in args.exclude_modes
612 )
613 return cond1 and cond2 and cond3 and cond4
584 614
585 def collate_fn(examples): 615 def collate_fn(examples):
586 prompts = [example["prompts"] for example in examples] 616 prompts = [example["prompts"] for example in examples]
@@ -617,7 +647,6 @@ def main():
617 num_class_images=args.num_class_images, 647 num_class_images=args.num_class_images,
618 size=args.resolution, 648 size=args.resolution,
619 repeats=args.repeats, 649 repeats=args.repeats,
620 mode=args.mode,
621 dropout=args.tag_dropout, 650 dropout=args.tag_dropout,
622 center_crop=args.center_crop, 651 center_crop=args.center_crop,
623 template_key=args.train_data_template, 652 template_key=args.train_data_template,
@@ -769,7 +798,7 @@ def main():
769 target, target_prior = torch.chunk(target, 2, dim=0) 798 target, target_prior = torch.chunk(target, 2, dim=0)
770 799
771 # Compute instance loss 800 # Compute instance loss
772 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 801 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
773 802
774 # Compute prior loss 803 # Compute prior loss
775 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 804 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
@@ -785,7 +814,7 @@ def main():
785 814
786 if args.find_lr: 815 if args.find_lr:
787 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 816 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
788 lr_finder.run(min_lr=1e-6, num_train_batches=4) 817 lr_finder.run(min_lr=1e-6, num_train_batches=1)
789 818
790 plt.savefig(basepath.joinpath("lr.png")) 819 plt.savefig(basepath.joinpath("lr.png"))
791 plt.close() 820 plt.close()
@@ -798,6 +827,10 @@ def main():
798 config = vars(args).copy() 827 config = vars(args).copy()
799 config["initializer_token"] = " ".join(config["initializer_token"]) 828 config["initializer_token"] = " ".join(config["initializer_token"])
800 config["placeholder_token"] = " ".join(config["placeholder_token"]) 829 config["placeholder_token"] = " ".join(config["placeholder_token"])
830 if config["exclude_modes"] is not None:
831 config["exclude_modes"] = " ".join(config["exclude_modes"])
832 if config["exclude_keywords"] is not None:
833 config["exclude_keywords"] = " ".join(config["exclude_keywords"])
801 accelerator.init_trackers("textual_inversion", config=config) 834 accelerator.init_trackers("textual_inversion", config=config)
802 835
803 # Train! 836 # Train!