summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-30 14:04:59 +0100
committerVolpeon <git@volpeon.ink>2022-12-30 14:04:59 +0100
commit799a2ed9c9735d11887600ee57ebb7471cdf6f43 (patch)
tree22a982d7348762f3cc55e91ba1e173f14c86cb99 /train_ti.py
parentTraining script improvements (diff)
downloadtextual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.gz
textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.bz2
textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.zip
Misc improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py47
1 files changed, 18 insertions, 29 deletions
diff --git a/train_ti.py b/train_ti.py
index 6aa4007..088c1a6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,16 +93,10 @@ def parse_args():
93 help="The directory where class images will be saved.", 93 help="The directory where class images will be saved.",
94 ) 94 )
95 parser.add_argument( 95 parser.add_argument(
96 "--exclude_keywords", 96 "--exclude_collections",
97 type=str, 97 type=str,
98 nargs='*', 98 nargs='*',
99 help="Skip dataset items containing a listed keyword.", 99 help="Exclude all items with a listed collection.",
100 )
101 parser.add_argument(
102 "--exclude_modes",
103 type=str,
104 nargs='*',
105 help="Exclude all items with a listed mode.",
106 ) 100 )
107 parser.add_argument( 101 parser.add_argument(
108 "--repeats", 102 "--repeats",
@@ -123,10 +117,10 @@ def parse_args():
123 help="The embeddings directory where Textual Inversion embeddings are stored.", 117 help="The embeddings directory where Textual Inversion embeddings are stored.",
124 ) 118 )
125 parser.add_argument( 119 parser.add_argument(
126 "--mode", 120 "--collection",
127 type=str, 121 type=str,
128 default=None, 122 nargs='*',
129 help="A mode to filter the dataset.", 123 help="A collection to filter the dataset.",
130 ) 124 )
131 parser.add_argument( 125 parser.add_argument(
132 "--seed", 126 "--seed",
@@ -369,11 +363,11 @@ def parse_args():
369 if len(args.placeholder_token) != len(args.initializer_token): 363 if len(args.placeholder_token) != len(args.initializer_token):
370 raise ValueError("You must specify --placeholder_token") 364 raise ValueError("You must specify --placeholder_token")
371 365
372 if isinstance(args.exclude_keywords, str): 366 if isinstance(args.collection, str):
373 args.exclude_keywords = [args.exclude_keywords] 367 args.collection = [args.collection]
374 368
375 if isinstance(args.exclude_modes, str): 369 if isinstance(args.exclude_collections, str):
376 args.exclude_modes = [args.exclude_modes] 370 args.exclude_collections = [args.exclude_collections]
377 371
378 if args.output_dir is None: 372 if args.output_dir is None:
379 raise ValueError("You must specify --output_dir") 373 raise ValueError("You must specify --output_dir")
@@ -600,17 +594,12 @@ def main():
600 for keyword in args.placeholder_token 594 for keyword in args.placeholder_token
601 for part in item.prompt 595 for part in item.prompt
602 ) 596 )
603 cond2 = args.exclude_keywords is None or not any( 597 cond3 = args.collection is None or args.collection in item.collection
604 keyword in part 598 cond4 = args.exclude_collections is None or not any(
605 for keyword in args.exclude_keywords 599 collection in item.collection
606 for part in item.prompt 600 for collection in args.exclude_collections
607 )
608 cond3 = args.mode is None or args.mode in item.mode
609 cond4 = args.exclude_modes is None or not any(
610 mode in item.mode
611 for mode in args.exclude_modes
612 ) 601 )
613 return cond1 and cond2 and cond3 and cond4 602 return cond1 and cond3 and cond4
614 603
615 def collate_fn(examples): 604 def collate_fn(examples):
616 prompts = [example["prompts"] for example in examples] 605 prompts = [example["prompts"] for example in examples]
@@ -827,10 +816,10 @@ def main():
827 config = vars(args).copy() 816 config = vars(args).copy()
828 config["initializer_token"] = " ".join(config["initializer_token"]) 817 config["initializer_token"] = " ".join(config["initializer_token"])
829 config["placeholder_token"] = " ".join(config["placeholder_token"]) 818 config["placeholder_token"] = " ".join(config["placeholder_token"])
830 if config["exclude_modes"] is not None: 819 if config["collection"] is not None:
831 config["exclude_modes"] = " ".join(config["exclude_modes"]) 820 config["collection"] = " ".join(config["collection"])
832 if config["exclude_keywords"] is not None: 821 if config["exclude_collections"] is not None:
833 config["exclude_keywords"] = " ".join(config["exclude_keywords"]) 822 config["exclude_collections"] = " ".join(config["exclude_collections"])
834 accelerator.init_trackers("textual_inversion", config=config) 823 accelerator.init_trackers("textual_inversion", config=config)
835 824
836 # Train! 825 # Train!