diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-30 14:04:59 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-30 14:04:59 +0100 |
| commit | 799a2ed9c9735d11887600ee57ebb7471cdf6f43 (patch) | |
| tree | 22a982d7348762f3cc55e91ba1e173f14c86cb99 /train_dreambooth.py | |
| parent | Training script improvements (diff) | |
| download | textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.gz textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.bz2 textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.zip | |
Misc improvements
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 47 |
1 files changed, 18 insertions, 29 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 072150b..8fd78f1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -83,16 +83,10 @@ def parse_args(): | |||
| 83 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
| 84 | ) | 84 | ) |
| 85 | parser.add_argument( | 85 | parser.add_argument( |
| 86 | "--exclude_keywords", | 86 | "--exclude_collections", |
| 87 | type=str, | 87 | type=str, |
| 88 | nargs='*', | 88 | nargs='*', |
| 89 | help="Skip dataset items containing a listed keyword.", | 89 | help="Exclude all items with a listed collection.", |
| 90 | ) | ||
| 91 | parser.add_argument( | ||
| 92 | "--exclude_modes", | ||
| 93 | type=str, | ||
| 94 | nargs='*', | ||
| 95 | help="Exclude all items with a listed mode.", | ||
| 96 | ) | 90 | ) |
| 97 | parser.add_argument( | 91 | parser.add_argument( |
| 98 | "--train_text_encoder", | 92 | "--train_text_encoder", |
| @@ -142,10 +136,10 @@ def parse_args(): | |||
| 142 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 136 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 143 | ) | 137 | ) |
| 144 | parser.add_argument( | 138 | parser.add_argument( |
| 145 | "--mode", | 139 | "--collection", |
| 146 | type=str, | 140 | type=str, |
| 147 | default=None, | 141 | nargs='*', |
| 148 | help="A mode to filter the dataset.", | 142 | help="A collection to filter the dataset.", |
| 149 | ) | 143 | ) |
| 150 | parser.add_argument( | 144 | parser.add_argument( |
| 151 | "--seed", | 145 | "--seed", |
| @@ -391,11 +385,11 @@ def parse_args(): | |||
| 391 | if len(args.placeholder_token) != len(args.initializer_token): | 385 | if len(args.placeholder_token) != len(args.initializer_token): |
| 392 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 386 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
| 393 | 387 | ||
| 394 | if isinstance(args.exclude_keywords, str): | 388 | if isinstance(args.collection, str): |
| 395 | args.exclude_keywords = [args.exclude_keywords] | 389 | args.collection = [args.collection] |
| 396 | 390 | ||
| 397 | if isinstance(args.exclude_modes, str): | 391 | if isinstance(args.exclude_collections, str): |
| 398 | args.exclude_modes = [args.exclude_modes] | 392 | args.exclude_collections = [args.exclude_collections] |
| 399 | 393 | ||
| 400 | if args.output_dir is None: | 394 | if args.output_dir is None: |
| 401 | raise ValueError("You must specify --output_dir") | 395 | raise ValueError("You must specify --output_dir") |
| @@ -655,17 +649,12 @@ def main(): | |||
| 655 | weight_dtype = torch.bfloat16 | 649 | weight_dtype = torch.bfloat16 |
| 656 | 650 | ||
| 657 | def keyword_filter(item: CSVDataItem): | 651 | def keyword_filter(item: CSVDataItem): |
| 658 | cond2 = args.exclude_keywords is None or not any( | 652 | cond3 = args.collection is None or args.collection in item.collection |
| 659 | keyword in part | 653 | cond4 = args.exclude_collections is None or not any( |
| 660 | for keyword in args.exclude_keywords | 654 | collection in item.collection |
| 661 | for part in item.prompt | 655 | for collection in args.exclude_collections |
| 662 | ) | ||
| 663 | cond3 = args.mode is None or args.mode in item.mode | ||
| 664 | cond4 = args.exclude_modes is None or not any( | ||
| 665 | mode in item.mode | ||
| 666 | for mode in args.exclude_modes | ||
| 667 | ) | 656 | ) |
| 668 | return cond2 and cond3 and cond4 | 657 | return cond3 and cond4 |
| 669 | 658 | ||
| 670 | def collate_fn(examples): | 659 | def collate_fn(examples): |
| 671 | prompts = [example["prompts"] for example in examples] | 660 | prompts = [example["prompts"] for example in examples] |
| @@ -813,10 +802,10 @@ def main(): | |||
| 813 | config = vars(args).copy() | 802 | config = vars(args).copy() |
| 814 | config["initializer_token"] = " ".join(config["initializer_token"]) | 803 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| 815 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 804 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
| 816 | if config["exclude_modes"] is not None: | 805 | if config["collection"] is not None: |
| 817 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | 806 | config["collection"] = " ".join(config["collection"]) |
| 818 | if config["exclude_keywords"] is not None: | 807 | if config["exclude_collections"] is not None: |
| 819 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | 808 | config["exclude_collections"] = " ".join(config["exclude_collections"]) |
| 820 | accelerator.init_trackers("dreambooth", config=config) | 809 | accelerator.init_trackers("dreambooth", config=config) |
| 821 | 810 | ||
| 822 | # Train! | 811 | # Train! |
