diff options
author | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-30 13:48:26 +0100 |
commit | dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch) | |
tree | da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2 textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip |
Training script improvements
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 41 |
1 files changed, 38 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 202d52c..072150b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -22,7 +22,7 @@ from slugify import slugify | |||
22 | 22 | ||
23 | from common import load_text_embeddings, load_config | 23 | from common import load_text_embeddings, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.ti import patch_trainable_embeddings | 27 | from training.ti import patch_trainable_embeddings |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
@@ -83,6 +83,18 @@ def parse_args(): | |||
83 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
84 | ) | 84 | ) |
85 | parser.add_argument( | 85 | parser.add_argument( |
86 | "--exclude_keywords", | ||
87 | type=str, | ||
88 | nargs='*', | ||
89 | help="Skip dataset items containing a listed keyword.", | ||
90 | ) | ||
91 | parser.add_argument( | ||
92 | "--exclude_modes", | ||
93 | type=str, | ||
94 | nargs='*', | ||
95 | help="Exclude all items with a listed mode.", | ||
96 | ) | ||
97 | parser.add_argument( | ||
86 | "--train_text_encoder", | 98 | "--train_text_encoder", |
87 | action="store_true", | 99 | action="store_true", |
88 | default=True, | 100 | default=True, |
@@ -379,6 +391,12 @@ def parse_args(): | |||
379 | if len(args.placeholder_token) != len(args.initializer_token): | 391 | if len(args.placeholder_token) != len(args.initializer_token): |
380 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 392 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
381 | 393 | ||
394 | if isinstance(args.exclude_keywords, str): | ||
395 | args.exclude_keywords = [args.exclude_keywords] | ||
396 | |||
397 | if isinstance(args.exclude_modes, str): | ||
398 | args.exclude_modes = [args.exclude_modes] | ||
399 | |||
382 | if args.output_dir is None: | 400 | if args.output_dir is None: |
383 | raise ValueError("You must specify --output_dir") | 401 | raise ValueError("You must specify --output_dir") |
384 | 402 | ||
@@ -636,6 +654,19 @@ def main(): | |||
636 | elif args.mixed_precision == "bf16": | 654 | elif args.mixed_precision == "bf16": |
637 | weight_dtype = torch.bfloat16 | 655 | weight_dtype = torch.bfloat16 |
638 | 656 | ||
657 | def keyword_filter(item: CSVDataItem): | ||
658 | cond2 = args.exclude_keywords is None or not any( | ||
659 | keyword in part | ||
660 | for keyword in args.exclude_keywords | ||
661 | for part in item.prompt | ||
662 | ) | ||
663 | cond3 = args.mode is None or args.mode in item.mode | ||
664 | cond4 = args.exclude_modes is None or not any( | ||
665 | mode in item.mode | ||
666 | for mode in args.exclude_modes | ||
667 | ) | ||
668 | return cond2 and cond3 and cond4 | ||
669 | |||
639 | def collate_fn(examples): | 670 | def collate_fn(examples): |
640 | prompts = [example["prompts"] for example in examples] | 671 | prompts = [example["prompts"] for example in examples] |
641 | cprompts = [example["cprompts"] for example in examples] | 672 | cprompts = [example["cprompts"] for example in examples] |
@@ -671,12 +702,12 @@ def main(): | |||
671 | num_class_images=args.num_class_images, | 702 | num_class_images=args.num_class_images, |
672 | size=args.resolution, | 703 | size=args.resolution, |
673 | repeats=args.repeats, | 704 | repeats=args.repeats, |
674 | mode=args.mode, | ||
675 | dropout=args.tag_dropout, | 705 | dropout=args.tag_dropout, |
676 | center_crop=args.center_crop, | 706 | center_crop=args.center_crop, |
677 | template_key=args.train_data_template, | 707 | template_key=args.train_data_template, |
678 | valid_set_size=args.valid_set_size, | 708 | valid_set_size=args.valid_set_size, |
679 | num_workers=args.dataloader_num_workers, | 709 | num_workers=args.dataloader_num_workers, |
710 | filter=keyword_filter, | ||
680 | collate_fn=collate_fn | 711 | collate_fn=collate_fn |
681 | ) | 712 | ) |
682 | 713 | ||
@@ -782,6 +813,10 @@ def main(): | |||
782 | config = vars(args).copy() | 813 | config = vars(args).copy() |
783 | config["initializer_token"] = " ".join(config["initializer_token"]) | 814 | config["initializer_token"] = " ".join(config["initializer_token"]) |
784 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 815 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
816 | if config["exclude_modes"] is not None: | ||
817 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | ||
818 | if config["exclude_keywords"] is not None: | ||
819 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | ||
785 | accelerator.init_trackers("dreambooth", config=config) | 820 | accelerator.init_trackers("dreambooth", config=config) |
786 | 821 | ||
787 | # Train! | 822 | # Train! |
@@ -879,7 +914,7 @@ def main(): | |||
879 | target, target_prior = torch.chunk(target, 2, dim=0) | 914 | target, target_prior = torch.chunk(target, 2, dim=0) |
880 | 915 | ||
881 | # Compute instance loss | 916 | # Compute instance loss |
882 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 917 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
883 | 918 | ||
884 | # Compute prior loss | 919 | # Compute prior loss |
885 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 920 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |