diff options
author | Volpeon <git@volpeon.ink> | 2022-12-30 14:04:59 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-30 14:04:59 +0100 |
commit | 799a2ed9c9735d11887600ee57ebb7471cdf6f43 (patch) | |
tree | 22a982d7348762f3cc55e91ba1e173f14c86cb99 /train_dreambooth.py | |
parent | Training script improvements (diff) | |
download | textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.gz textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.bz2 textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.zip |
Misc improvements
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 47 |
1 files changed, 18 insertions, 29 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 072150b..8fd78f1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -83,16 +83,10 @@ def parse_args(): | |||
83 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
84 | ) | 84 | ) |
85 | parser.add_argument( | 85 | parser.add_argument( |
86 | "--exclude_keywords", | 86 | "--exclude_collections", |
87 | type=str, | 87 | type=str, |
88 | nargs='*', | 88 | nargs='*', |
89 | help="Skip dataset items containing a listed keyword.", | 89 | help="Exclude all items with a listed collection.", |
90 | ) | ||
91 | parser.add_argument( | ||
92 | "--exclude_modes", | ||
93 | type=str, | ||
94 | nargs='*', | ||
95 | help="Exclude all items with a listed mode.", | ||
96 | ) | 90 | ) |
97 | parser.add_argument( | 91 | parser.add_argument( |
98 | "--train_text_encoder", | 92 | "--train_text_encoder", |
@@ -142,10 +136,10 @@ def parse_args(): | |||
142 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 136 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
143 | ) | 137 | ) |
144 | parser.add_argument( | 138 | parser.add_argument( |
145 | "--mode", | 139 | "--collection", |
146 | type=str, | 140 | type=str, |
147 | default=None, | 141 | nargs='*', |
148 | help="A mode to filter the dataset.", | 142 | help="A collection to filter the dataset.", |
149 | ) | 143 | ) |
150 | parser.add_argument( | 144 | parser.add_argument( |
151 | "--seed", | 145 | "--seed", |
@@ -391,11 +385,11 @@ def parse_args(): | |||
391 | if len(args.placeholder_token) != len(args.initializer_token): | 385 | if len(args.placeholder_token) != len(args.initializer_token): |
392 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 386 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
393 | 387 | ||
394 | if isinstance(args.exclude_keywords, str): | 388 | if isinstance(args.collection, str): |
395 | args.exclude_keywords = [args.exclude_keywords] | 389 | args.collection = [args.collection] |
396 | 390 | ||
397 | if isinstance(args.exclude_modes, str): | 391 | if isinstance(args.exclude_collections, str): |
398 | args.exclude_modes = [args.exclude_modes] | 392 | args.exclude_collections = [args.exclude_collections] |
399 | 393 | ||
400 | if args.output_dir is None: | 394 | if args.output_dir is None: |
401 | raise ValueError("You must specify --output_dir") | 395 | raise ValueError("You must specify --output_dir") |
@@ -655,17 +649,12 @@ def main(): | |||
655 | weight_dtype = torch.bfloat16 | 649 | weight_dtype = torch.bfloat16 |
656 | 650 | ||
657 | def keyword_filter(item: CSVDataItem): | 651 | def keyword_filter(item: CSVDataItem): |
658 | cond2 = args.exclude_keywords is None or not any( | 652 | cond3 = args.collection is None or args.collection in item.collection |
659 | keyword in part | 653 | cond4 = args.exclude_collections is None or not any( |
660 | for keyword in args.exclude_keywords | 654 | collection in item.collection |
661 | for part in item.prompt | 655 | for collection in args.exclude_collections |
662 | ) | ||
663 | cond3 = args.mode is None or args.mode in item.mode | ||
664 | cond4 = args.exclude_modes is None or not any( | ||
665 | mode in item.mode | ||
666 | for mode in args.exclude_modes | ||
667 | ) | 656 | ) |
668 | return cond2 and cond3 and cond4 | 657 | return cond3 and cond4 |
669 | 658 | ||
670 | def collate_fn(examples): | 659 | def collate_fn(examples): |
671 | prompts = [example["prompts"] for example in examples] | 660 | prompts = [example["prompts"] for example in examples] |
@@ -813,10 +802,10 @@ def main(): | |||
813 | config = vars(args).copy() | 802 | config = vars(args).copy() |
814 | config["initializer_token"] = " ".join(config["initializer_token"]) | 803 | config["initializer_token"] = " ".join(config["initializer_token"]) |
815 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 804 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
816 | if config["exclude_modes"] is not None: | 805 | if config["collection"] is not None: |
817 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | 806 | config["collection"] = " ".join(config["collection"]) |
818 | if config["exclude_keywords"] is not None: | 807 | if config["exclude_collections"] is not None: |
819 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | 808 | config["exclude_collections"] = " ".join(config["exclude_collections"]) |
820 | accelerator.init_trackers("dreambooth", config=config) | 809 | accelerator.init_trackers("dreambooth", config=config) |
821 | 810 | ||
822 | # Train! | 811 | # Train! |