summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-30 14:04:59 +0100
committerVolpeon <git@volpeon.ink>2022-12-30 14:04:59 +0100
commit799a2ed9c9735d11887600ee57ebb7471cdf6f43 (patch)
tree22a982d7348762f3cc55e91ba1e173f14c86cb99 /train_dreambooth.py
parentTraining script improvements (diff)
downloadtextual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.gz
textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.bz2
textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.zip
Misc improvements
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py47
1 files changed, 18 insertions, 29 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 072150b..8fd78f1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -83,16 +83,10 @@ def parse_args():
83 help="A token to use as initializer word." 83 help="A token to use as initializer word."
84 ) 84 )
85 parser.add_argument( 85 parser.add_argument(
86 "--exclude_keywords", 86 "--exclude_collections",
87 type=str, 87 type=str,
88 nargs='*', 88 nargs='*',
89 help="Skip dataset items containing a listed keyword.", 89 help="Exclude all items with a listed collection.",
90 )
91 parser.add_argument(
92 "--exclude_modes",
93 type=str,
94 nargs='*',
95 help="Exclude all items with a listed mode.",
96 ) 90 )
97 parser.add_argument( 91 parser.add_argument(
98 "--train_text_encoder", 92 "--train_text_encoder",
@@ -142,10 +136,10 @@ def parse_args():
142 help="The embeddings directory where Textual Inversion embeddings are stored.", 136 help="The embeddings directory where Textual Inversion embeddings are stored.",
143 ) 137 )
144 parser.add_argument( 138 parser.add_argument(
145 "--mode", 139 "--collection",
146 type=str, 140 type=str,
147 default=None, 141 nargs='*',
148 help="A mode to filter the dataset.", 142 help="A collection to filter the dataset.",
149 ) 143 )
150 parser.add_argument( 144 parser.add_argument(
151 "--seed", 145 "--seed",
@@ -391,11 +385,11 @@ def parse_args():
391 if len(args.placeholder_token) != len(args.initializer_token): 385 if len(args.placeholder_token) != len(args.initializer_token):
392 raise ValueError("Number of items in --placeholder_token and --initializer_token must match") 386 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
393 387
394 if isinstance(args.exclude_keywords, str): 388 if isinstance(args.collection, str):
395 args.exclude_keywords = [args.exclude_keywords] 389 args.collection = [args.collection]
396 390
397 if isinstance(args.exclude_modes, str): 391 if isinstance(args.exclude_collections, str):
398 args.exclude_modes = [args.exclude_modes] 392 args.exclude_collections = [args.exclude_collections]
399 393
400 if args.output_dir is None: 394 if args.output_dir is None:
401 raise ValueError("You must specify --output_dir") 395 raise ValueError("You must specify --output_dir")
@@ -655,17 +649,12 @@ def main():
655 weight_dtype = torch.bfloat16 649 weight_dtype = torch.bfloat16
656 650
657 def keyword_filter(item: CSVDataItem): 651 def keyword_filter(item: CSVDataItem):
658 cond2 = args.exclude_keywords is None or not any( 652 cond3 = args.collection is None or args.collection in item.collection
659 keyword in part 653 cond4 = args.exclude_collections is None or not any(
660 for keyword in args.exclude_keywords 654 collection in item.collection
661 for part in item.prompt 655 for collection in args.exclude_collections
662 )
663 cond3 = args.mode is None or args.mode in item.mode
664 cond4 = args.exclude_modes is None or not any(
665 mode in item.mode
666 for mode in args.exclude_modes
667 ) 656 )
668 return cond2 and cond3 and cond4 657 return cond3 and cond4
669 658
670 def collate_fn(examples): 659 def collate_fn(examples):
671 prompts = [example["prompts"] for example in examples] 660 prompts = [example["prompts"] for example in examples]
@@ -813,10 +802,10 @@ def main():
813 config = vars(args).copy() 802 config = vars(args).copy()
814 config["initializer_token"] = " ".join(config["initializer_token"]) 803 config["initializer_token"] = " ".join(config["initializer_token"])
815 config["placeholder_token"] = " ".join(config["placeholder_token"]) 804 config["placeholder_token"] = " ".join(config["placeholder_token"])
816 if config["exclude_modes"] is not None: 805 if config["collection"] is not None:
817 config["exclude_modes"] = " ".join(config["exclude_modes"]) 806 config["collection"] = " ".join(config["collection"])
818 if config["exclude_keywords"] is not None: 807 if config["exclude_collections"] is not None:
819 config["exclude_keywords"] = " ".join(config["exclude_keywords"]) 808 config["exclude_collections"] = " ".join(config["exclude_collections"])
820 accelerator.init_trackers("dreambooth", config=config) 809 accelerator.init_trackers("dreambooth", config=config)
821 810
822 # Train! 811 # Train!