path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>  2022-12-30 13:48:26 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-30 13:48:26 +0100
commit     dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch)
tree       da07cbadfad6f54e55e43e2fda21cef80cded5ea /train_dreambooth.py
parent     Update (diff)
download   textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz
           textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2
           textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip
Training script improvements
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  41
1 file changed, 38 insertions(+), 3 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 202d52c..072150b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
 
 from common import load_text_embeddings, load_config
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -83,6 +83,18 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--exclude_keywords",
+        type=str,
+        nargs='*',
+        help="Skip dataset items containing a listed keyword.",
+    )
+    parser.add_argument(
+        "--exclude_modes",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed mode.",
+    )
+    parser.add_argument(
         "--train_text_encoder",
         action="store_true",
         default=True,
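
Note: a minimal standalone sketch (not part of the diff) of how these nargs='*' flags behave in argparse; when a flag is omitted its value stays None, which the "is None" checks in keyword_filter further down rely on:

    # Standalone illustration only; parser setup and flag values are examples.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--exclude_keywords", type=str, nargs='*')
    parser.add_argument("--exclude_modes", type=str, nargs='*')

    args = parser.parse_args(["--exclude_keywords", "blurry", "lowres"])
    print(args.exclude_keywords)  # ['blurry', 'lowres']
    print(args.exclude_modes)     # None (flag not given)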
@@ -379,6 +391,12 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
 
+    if isinstance(args.exclude_keywords, str):
+        args.exclude_keywords = [args.exclude_keywords]
+
+    if isinstance(args.exclude_modes, str):
+        args.exclude_modes = [args.exclude_modes]
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -636,6 +654,19 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        cond2 = args.exclude_keywords is None or not any(
+            keyword in part
+            for keyword in args.exclude_keywords
+            for part in item.prompt
+        )
+        cond3 = args.mode is None or args.mode in item.mode
+        cond4 = args.exclude_modes is None or not any(
+            mode in item.mode
+            for mode in args.exclude_modes
+        )
+        return cond2 and cond3 and cond4
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         cprompts = [example["cprompts"] for example in examples]
@@ -671,12 +702,12 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
-        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        filter=keyword_filter,
        collate_fn=collate_fn
     )
 
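
Note: this diff does not show how CSVDataModule consumes the new filter argument. The sketch below only illustrates applying a per-item predicate such as keyword_filter; the Item dataclass and its fields are stand-ins for the real CSVDataItem:

    # Illustration only: apply a per-item predicate while building a dataset.
    # Item is a stand-in for CSVDataItem; field names mirror how keyword_filter
    # uses them (item.prompt iterated as parts, item.mode as a tag collection).
    from dataclasses import dataclass
    from typing import Callable, List, Optional

    @dataclass
    class Item:
        prompt: List[str]
        mode: List[str]

    def apply_filter(items: List[Item], predicate: Optional[Callable[[Item], bool]]) -> List[Item]:
        if predicate is None:
            return items  # no filter supplied: keep everything
        return [item for item in items if predicate(item)]

    exclude_keywords = ["blurry"]
    items = [
        Item(prompt=["a photo of sks dog", "blurry"], mode=["portrait"]),
        Item(prompt=["a photo of sks dog", "sharp"], mode=["portrait"]),
    ]
    kept = apply_filter(
        items,
        lambda item: not any(k in part for k in exclude_keywords for part in item.prompt),
    )
    print(len(kept))  # 1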
@@ -782,6 +813,10 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    if config["exclude_modes"] is not None:
+        config["exclude_modes"] = " ".join(config["exclude_modes"])
+    if config["exclude_keywords"] is not None:
+        config["exclude_keywords"] = " ".join(config["exclude_keywords"])
     accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
@@ -879,7 +914,7 @@ def main():
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     # Compute prior loss
                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
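
Note: the instance-loss change in the last hunk is a simplification rather than a behavioural change: for a batch of equally shaped samples, the mean of per-sample means equals the mean over all elements. A quick check, assuming torch is installed (tensor shapes are arbitrary examples):

    # reduction="mean" matches the old reduction="none" -> per-sample mean -> batch mean.
    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)

    old = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()
    new = F.mse_loss(pred, target, reduction="mean")

    print(torch.allclose(old, new))  # True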