diff options
author | Volpeon <git@volpeon.ink> | 2022-12-30 14:04:59 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-30 14:04:59 +0100 |
commit | 799a2ed9c9735d11887600ee57ebb7471cdf6f43 (patch) | |
tree | 22a982d7348762f3cc55e91ba1e173f14c86cb99 /train_ti.py | |
parent | Training script improvements (diff) | |
download | textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.gz textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.bz2 textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.zip |
Misc improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 47 |
1 files changed, 18 insertions, 29 deletions
diff --git a/train_ti.py b/train_ti.py index 6aa4007..088c1a6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -93,16 +93,10 @@ def parse_args(): | |||
93 | help="The directory where class images will be saved.", | 93 | help="The directory where class images will be saved.", |
94 | ) | 94 | ) |
95 | parser.add_argument( | 95 | parser.add_argument( |
96 | "--exclude_keywords", | 96 | "--exclude_collections", |
97 | type=str, | 97 | type=str, |
98 | nargs='*', | 98 | nargs='*', |
99 | help="Skip dataset items containing a listed keyword.", | 99 | help="Exclude all items with a listed collection.", |
100 | ) | ||
101 | parser.add_argument( | ||
102 | "--exclude_modes", | ||
103 | type=str, | ||
104 | nargs='*', | ||
105 | help="Exclude all items with a listed mode.", | ||
106 | ) | 100 | ) |
107 | parser.add_argument( | 101 | parser.add_argument( |
108 | "--repeats", | 102 | "--repeats", |
@@ -123,10 +117,10 @@ def parse_args(): | |||
123 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 117 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
124 | ) | 118 | ) |
125 | parser.add_argument( | 119 | parser.add_argument( |
126 | "--mode", | 120 | "--collection", |
127 | type=str, | 121 | type=str, |
128 | default=None, | 122 | nargs='*', |
129 | help="A mode to filter the dataset.", | 123 | help="A collection to filter the dataset.", |
130 | ) | 124 | ) |
131 | parser.add_argument( | 125 | parser.add_argument( |
132 | "--seed", | 126 | "--seed", |
@@ -369,11 +363,11 @@ def parse_args(): | |||
369 | if len(args.placeholder_token) != len(args.initializer_token): | 363 | if len(args.placeholder_token) != len(args.initializer_token): |
370 | raise ValueError("You must specify --placeholder_token") | 364 | raise ValueError("You must specify --placeholder_token") |
371 | 365 | ||
372 | if isinstance(args.exclude_keywords, str): | 366 | if isinstance(args.collection, str): |
373 | args.exclude_keywords = [args.exclude_keywords] | 367 | args.collection = [args.collection] |
374 | 368 | ||
375 | if isinstance(args.exclude_modes, str): | 369 | if isinstance(args.exclude_collections, str): |
376 | args.exclude_modes = [args.exclude_modes] | 370 | args.exclude_collections = [args.exclude_collections] |
377 | 371 | ||
378 | if args.output_dir is None: | 372 | if args.output_dir is None: |
379 | raise ValueError("You must specify --output_dir") | 373 | raise ValueError("You must specify --output_dir") |
@@ -600,17 +594,12 @@ def main(): | |||
600 | for keyword in args.placeholder_token | 594 | for keyword in args.placeholder_token |
601 | for part in item.prompt | 595 | for part in item.prompt |
602 | ) | 596 | ) |
603 | cond2 = args.exclude_keywords is None or not any( | 597 | cond3 = args.collection is None or args.collection in item.collection |
604 | keyword in part | 598 | cond4 = args.exclude_collections is None or not any( |
605 | for keyword in args.exclude_keywords | 599 | collection in item.collection |
606 | for part in item.prompt | 600 | for collection in args.exclude_collections |
607 | ) | ||
608 | cond3 = args.mode is None or args.mode in item.mode | ||
609 | cond4 = args.exclude_modes is None or not any( | ||
610 | mode in item.mode | ||
611 | for mode in args.exclude_modes | ||
612 | ) | 601 | ) |
613 | return cond1 and cond2 and cond3 and cond4 | 602 | return cond1 and cond3 and cond4 |
614 | 603 | ||
615 | def collate_fn(examples): | 604 | def collate_fn(examples): |
616 | prompts = [example["prompts"] for example in examples] | 605 | prompts = [example["prompts"] for example in examples] |
@@ -827,10 +816,10 @@ def main(): | |||
827 | config = vars(args).copy() | 816 | config = vars(args).copy() |
828 | config["initializer_token"] = " ".join(config["initializer_token"]) | 817 | config["initializer_token"] = " ".join(config["initializer_token"]) |
829 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 818 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
830 | if config["exclude_modes"] is not None: | 819 | if config["collection"] is not None: |
831 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | 820 | config["collection"] = " ".join(config["collection"]) |
832 | if config["exclude_keywords"] is not None: | 821 | if config["exclude_collections"] is not None: |
833 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | 822 | config["exclude_collections"] = " ".join(config["exclude_collections"]) |
834 | accelerator.init_trackers("textual_inversion", config=config) | 823 | accelerator.init_trackers("textual_inversion", config=config) |
835 | 824 | ||
836 | # Train! | 825 | # Train! |