diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
| commit | dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch) | |
| tree | da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2 textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip | |
Training script improvements
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 43 |
1 file changed, 38 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index b1f6a49..6aa4007 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -93,6 +93,18 @@ def parse_args(): | |||
| 93 | help="The directory where class images will be saved.", | 93 | help="The directory where class images will be saved.", |
| 94 | ) | 94 | ) |
| 95 | parser.add_argument( | 95 | parser.add_argument( |
| 96 | "--exclude_keywords", | ||
| 97 | type=str, | ||
| 98 | nargs='*', | ||
| 99 | help="Skip dataset items containing a listed keyword.", | ||
| 100 | ) | ||
| 101 | parser.add_argument( | ||
| 102 | "--exclude_modes", | ||
| 103 | type=str, | ||
| 104 | nargs='*', | ||
| 105 | help="Exclude all items with a listed mode.", | ||
| 106 | ) | ||
| 107 | parser.add_argument( | ||
| 96 | "--repeats", | 108 | "--repeats", |
| 97 | type=int, | 109 | type=int, |
| 98 | default=1, | 110 | default=1, |
| @@ -120,7 +132,8 @@ def parse_args(): | |||
| 120 | "--seed", | 132 | "--seed", |
| 121 | type=int, | 133 | type=int, |
| 122 | default=None, | 134 | default=None, |
| 123 | help="A seed for reproducible training.") | 135 | help="A seed for reproducible training." |
| 136 | ) | ||
| 124 | parser.add_argument( | 137 | parser.add_argument( |
| 125 | "--resolution", | 138 | "--resolution", |
| 126 | type=int, | 139 | type=int, |
| @@ -356,6 +369,12 @@ def parse_args(): | |||
| 356 | if len(args.placeholder_token) != len(args.initializer_token): | 369 | if len(args.placeholder_token) != len(args.initializer_token): |
| 357 | raise ValueError("You must specify --placeholder_token") | 370 | raise ValueError("You must specify --placeholder_token") |
| 358 | 371 | ||
| 372 | if isinstance(args.exclude_keywords, str): | ||
| 373 | args.exclude_keywords = [args.exclude_keywords] | ||
| 374 | |||
| 375 | if isinstance(args.exclude_modes, str): | ||
| 376 | args.exclude_modes = [args.exclude_modes] | ||
| 377 | |||
| 359 | if args.output_dir is None: | 378 | if args.output_dir is None: |
| 360 | raise ValueError("You must specify --output_dir") | 379 | raise ValueError("You must specify --output_dir") |
| 361 | 380 | ||
| @@ -576,11 +595,22 @@ def main(): | |||
| 576 | weight_dtype = torch.bfloat16 | 595 | weight_dtype = torch.bfloat16 |
| 577 | 596 | ||
| 578 | def keyword_filter(item: CSVDataItem): | 597 | def keyword_filter(item: CSVDataItem): |
| 579 | return any( | 598 | cond1 = any( |
| 580 | keyword in part | 599 | keyword in part |
| 581 | for keyword in args.placeholder_token | 600 | for keyword in args.placeholder_token |
| 582 | for part in item.prompt | 601 | for part in item.prompt |
| 583 | ) | 602 | ) |
| 603 | cond2 = args.exclude_keywords is None or not any( | ||
| 604 | keyword in part | ||
| 605 | for keyword in args.exclude_keywords | ||
| 606 | for part in item.prompt | ||
| 607 | ) | ||
| 608 | cond3 = args.mode is None or args.mode in item.mode | ||
| 609 | cond4 = args.exclude_modes is None or not any( | ||
| 610 | mode in item.mode | ||
| 611 | for mode in args.exclude_modes | ||
| 612 | ) | ||
| 613 | return cond1 and cond2 and cond3 and cond4 | ||
| 584 | 614 | ||
| 585 | def collate_fn(examples): | 615 | def collate_fn(examples): |
| 586 | prompts = [example["prompts"] for example in examples] | 616 | prompts = [example["prompts"] for example in examples] |
| @@ -617,7 +647,6 @@ def main(): | |||
| 617 | num_class_images=args.num_class_images, | 647 | num_class_images=args.num_class_images, |
| 618 | size=args.resolution, | 648 | size=args.resolution, |
| 619 | repeats=args.repeats, | 649 | repeats=args.repeats, |
| 620 | mode=args.mode, | ||
| 621 | dropout=args.tag_dropout, | 650 | dropout=args.tag_dropout, |
| 622 | center_crop=args.center_crop, | 651 | center_crop=args.center_crop, |
| 623 | template_key=args.train_data_template, | 652 | template_key=args.train_data_template, |
| @@ -769,7 +798,7 @@ def main(): | |||
| 769 | target, target_prior = torch.chunk(target, 2, dim=0) | 798 | target, target_prior = torch.chunk(target, 2, dim=0) |
| 770 | 799 | ||
| 771 | # Compute instance loss | 800 | # Compute instance loss |
| 772 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 801 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 773 | 802 | ||
| 774 | # Compute prior loss | 803 | # Compute prior loss |
| 775 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 804 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |
| @@ -785,7 +814,7 @@ def main(): | |||
| 785 | 814 | ||
| 786 | if args.find_lr: | 815 | if args.find_lr: |
| 787 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 816 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
| 788 | lr_finder.run(min_lr=1e-6, num_train_batches=4) | 817 | lr_finder.run(min_lr=1e-6, num_train_batches=1) |
| 789 | 818 | ||
| 790 | plt.savefig(basepath.joinpath("lr.png")) | 819 | plt.savefig(basepath.joinpath("lr.png")) |
| 791 | plt.close() | 820 | plt.close() |
| @@ -798,6 +827,10 @@ def main(): | |||
| 798 | config = vars(args).copy() | 827 | config = vars(args).copy() |
| 799 | config["initializer_token"] = " ".join(config["initializer_token"]) | 828 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| 800 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 829 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
| 830 | if config["exclude_modes"] is not None: | ||
| 831 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | ||
| 832 | if config["exclude_keywords"] is not None: | ||
| 833 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | ||
| 801 | accelerator.init_trackers("textual_inversion", config=config) | 834 | accelerator.init_trackers("textual_inversion", config=config) |
| 802 | 835 | ||
| 803 | # Train! | 836 | # Train! |
