| author | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
| commit | dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch) | |
| tree | da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_dreambooth.py | |
| parent | Update (diff) | |
Training script improvements
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 41 |
1 file changed, 38 insertions, 3 deletions
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 202d52c..072150b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
 
 from common import load_text_embeddings, load_config
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -83,6 +83,18 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--exclude_keywords",
+        type=str,
+        nargs='*',
+        help="Skip dataset items containing a listed keyword.",
+    )
+    parser.add_argument(
+        "--exclude_modes",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed mode.",
+    )
+    parser.add_argument(
         "--train_text_encoder",
         action="store_true",
         default=True,
```
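The two new options take zero or more values (`nargs='*'`), so they arrive as lists, or `None` when the flag is omitted. A minimal stand-alone sketch of how they parse; the example values (`blurry`, `lowres`, `sketch`) are hypothetical:

```python
import argparse

# Minimal sketch of just the two new options; the real parser in
# train_dreambooth.py defines many more arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--exclude_keywords", type=str, nargs='*',
                    help="Skip dataset items containing a listed keyword.")
parser.add_argument("--exclude_modes", type=str, nargs='*',
                    help="Exclude all items with a listed mode.")

args = parser.parse_args(
    ["--exclude_keywords", "blurry", "lowres", "--exclude_modes", "sketch"]
)
print(args.exclude_keywords)                   # ['blurry', 'lowres']
print(args.exclude_modes)                      # ['sketch']
print(parser.parse_args([]).exclude_keywords)  # None -> no filtering
```

The `isinstance(..., str)` normalization in the next hunk presumably covers the case where a value arrives from the loaded config file as a single string rather than a list.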
```diff
@@ -379,6 +391,12 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
 
+    if isinstance(args.exclude_keywords, str):
+        args.exclude_keywords = [args.exclude_keywords]
+
+    if isinstance(args.exclude_modes, str):
+        args.exclude_modes = [args.exclude_modes]
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -636,6 +654,19 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        cond2 = args.exclude_keywords is None or not any(
+            keyword in part
+            for keyword in args.exclude_keywords
+            for part in item.prompt
+        )
+        cond3 = args.mode is None or args.mode in item.mode
+        cond4 = args.exclude_modes is None or not any(
+            mode in item.mode
+            for mode in args.exclude_modes
+        )
+        return cond2 and cond3 and cond4
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         cprompts = [example["cprompts"] for example in examples]
```
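The new `keyword_filter` is handed to the data module in a later hunk. Assuming `CSVDataItem.prompt` is a collection of text parts and `CSVDataItem.mode` a collection of mode tags (the exact fields live in `data/csv.py`), an item is kept only if no prompt part contains an excluded keyword, the required `--mode` (if set) is among its modes, and none of its modes are excluded. A small self-contained sketch with a stand-in item class and made-up values:

```python
from dataclasses import dataclass

# Hypothetical stand-in for data.csv.CSVDataItem; field names are assumed.
@dataclass
class Item:
    prompt: list  # prompt text parts
    mode: list    # mode tags

def keyword_filter(item, exclude_keywords=None, mode=None, exclude_modes=None):
    # Same logic as the script's filter, with the CLI args passed explicitly.
    no_excluded_keyword = exclude_keywords is None or not any(
        keyword in part for keyword in exclude_keywords for part in item.prompt
    )
    mode_matches = mode is None or mode in item.mode
    no_excluded_mode = exclude_modes is None or not any(
        m in item.mode for m in exclude_modes
    )
    return no_excluded_keyword and mode_matches and no_excluded_mode

item = Item(prompt=["a photo of sks dog", "blurry"], mode=["photo"])
print(keyword_filter(item, exclude_keywords=["blurry"]))  # False: keyword hit
print(keyword_filter(item, exclude_modes=["sketch"]))     # True: mode not excluded
print(keyword_filter(item, mode="photo"))                 # True: required mode present
```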
```diff
@@ -671,12 +702,12 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
-        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        filter=keyword_filter,
         collate_fn=collate_fn
     )
 
@@ -782,6 +813,10 @@ def main():
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
         config["placeholder_token"] = " ".join(config["placeholder_token"])
+        if config["exclude_modes"] is not None:
+            config["exclude_modes"] = " ".join(config["exclude_modes"])
+        if config["exclude_keywords"] is not None:
+            config["exclude_keywords"] = " ".join(config["exclude_keywords"])
         accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
@@ -879,7 +914,7 @@ def main():
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     # Compute prior loss
                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
```
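The instance-loss change replaces the per-sample mean (`reduction="none"` followed by `.mean([1, 2, 3]).mean()`) with a single global mean. Because every sample in the batch has the same shape, the two reductions give the same value, so this reads as a simplification rather than a behavioral change. A quick check, separate from the commit, with illustrative tensor shapes:

```python
import torch
import torch.nn.functional as F

# Any fixed (B, C, H, W) shape works; these values are only illustrative.
model_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)

old_loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
new_loss = F.mse_loss(model_pred, target, reduction="mean")
assert torch.allclose(old_loss, new_loss)
```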
