diff options
author | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
commit | dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch) | |
tree | da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2 textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip |
Training script improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 43 |
1 file changed, 38 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index b1f6a49..6aa4007 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -93,6 +93,18 @@ def parse_args(): | |||
93 | help="The directory where class images will be saved.", | 93 | help="The directory where class images will be saved.", |
94 | ) | 94 | ) |
95 | parser.add_argument( | 95 | parser.add_argument( |
96 | "--exclude_keywords", | ||
97 | type=str, | ||
98 | nargs='*', | ||
99 | help="Skip dataset items containing a listed keyword.", | ||
100 | ) | ||
101 | parser.add_argument( | ||
102 | "--exclude_modes", | ||
103 | type=str, | ||
104 | nargs='*', | ||
105 | help="Exclude all items with a listed mode.", | ||
106 | ) | ||
107 | parser.add_argument( | ||
96 | "--repeats", | 108 | "--repeats", |
97 | type=int, | 109 | type=int, |
98 | default=1, | 110 | default=1, |
@@ -120,7 +132,8 @@ def parse_args(): | |||
120 | "--seed", | 132 | "--seed", |
121 | type=int, | 133 | type=int, |
122 | default=None, | 134 | default=None, |
123 | help="A seed for reproducible training.") | 135 | help="A seed for reproducible training." |
136 | ) | ||
124 | parser.add_argument( | 137 | parser.add_argument( |
125 | "--resolution", | 138 | "--resolution", |
126 | type=int, | 139 | type=int, |
@@ -356,6 +369,12 @@ def parse_args(): | |||
356 | if len(args.placeholder_token) != len(args.initializer_token): | 369 | if len(args.placeholder_token) != len(args.initializer_token): |
357 | raise ValueError("You must specify --placeholder_token") | 370 | raise ValueError("You must specify --placeholder_token") |
358 | 371 | ||
372 | if isinstance(args.exclude_keywords, str): | ||
373 | args.exclude_keywords = [args.exclude_keywords] | ||
374 | |||
375 | if isinstance(args.exclude_modes, str): | ||
376 | args.exclude_modes = [args.exclude_modes] | ||
377 | |||
359 | if args.output_dir is None: | 378 | if args.output_dir is None: |
360 | raise ValueError("You must specify --output_dir") | 379 | raise ValueError("You must specify --output_dir") |
361 | 380 | ||
@@ -576,11 +595,22 @@ def main(): | |||
576 | weight_dtype = torch.bfloat16 | 595 | weight_dtype = torch.bfloat16 |
577 | 596 | ||
578 | def keyword_filter(item: CSVDataItem): | 597 | def keyword_filter(item: CSVDataItem): |
579 | return any( | 598 | cond1 = any( |
580 | keyword in part | 599 | keyword in part |
581 | for keyword in args.placeholder_token | 600 | for keyword in args.placeholder_token |
582 | for part in item.prompt | 601 | for part in item.prompt |
583 | ) | 602 | ) |
603 | cond2 = args.exclude_keywords is None or not any( | ||
604 | keyword in part | ||
605 | for keyword in args.exclude_keywords | ||
606 | for part in item.prompt | ||
607 | ) | ||
608 | cond3 = args.mode is None or args.mode in item.mode | ||
609 | cond4 = args.exclude_modes is None or not any( | ||
610 | mode in item.mode | ||
611 | for mode in args.exclude_modes | ||
612 | ) | ||
613 | return cond1 and cond2 and cond3 and cond4 | ||
584 | 614 | ||
585 | def collate_fn(examples): | 615 | def collate_fn(examples): |
586 | prompts = [example["prompts"] for example in examples] | 616 | prompts = [example["prompts"] for example in examples] |
@@ -617,7 +647,6 @@ def main(): | |||
617 | num_class_images=args.num_class_images, | 647 | num_class_images=args.num_class_images, |
618 | size=args.resolution, | 648 | size=args.resolution, |
619 | repeats=args.repeats, | 649 | repeats=args.repeats, |
620 | mode=args.mode, | ||
621 | dropout=args.tag_dropout, | 650 | dropout=args.tag_dropout, |
622 | center_crop=args.center_crop, | 651 | center_crop=args.center_crop, |
623 | template_key=args.train_data_template, | 652 | template_key=args.train_data_template, |
@@ -769,7 +798,7 @@ def main(): | |||
769 | target, target_prior = torch.chunk(target, 2, dim=0) | 798 | target, target_prior = torch.chunk(target, 2, dim=0) |
770 | 799 | ||
771 | # Compute instance loss | 800 | # Compute instance loss |
772 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 801 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
773 | 802 | ||
774 | # Compute prior loss | 803 | # Compute prior loss |
775 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 804 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |
@@ -785,7 +814,7 @@ def main(): | |||
785 | 814 | ||
786 | if args.find_lr: | 815 | if args.find_lr: |
787 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 816 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
788 | lr_finder.run(min_lr=1e-6, num_train_batches=4) | 817 | lr_finder.run(min_lr=1e-6, num_train_batches=1) |
789 | 818 | ||
790 | plt.savefig(basepath.joinpath("lr.png")) | 819 | plt.savefig(basepath.joinpath("lr.png")) |
791 | plt.close() | 820 | plt.close() |
@@ -798,6 +827,10 @@ def main(): | |||
798 | config = vars(args).copy() | 827 | config = vars(args).copy() |
799 | config["initializer_token"] = " ".join(config["initializer_token"]) | 828 | config["initializer_token"] = " ".join(config["initializer_token"]) |
800 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 829 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
830 | if config["exclude_modes"] is not None: | ||
831 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | ||
832 | if config["exclude_keywords"] is not None: | ||
833 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | ||
801 | accelerator.init_trackers("textual_inversion", config=config) | 834 | accelerator.init_trackers("textual_inversion", config=config) |
802 | 835 | ||
803 | # Train! | 836 | # Train! |