diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-23 06:48:38 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-23 06:48:38 +0200 |
| commit | 950f1f6bcbb1a767170cea590b828d8e3cdae882 (patch) | |
| tree | 019a1d3463b363a3f335e16c50cb6890efc3470f /train_dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-950f1f6bcbb1a767170cea590b828d8e3cdae882.tar.gz textual-inversion-diff-950f1f6bcbb1a767170cea590b828d8e3cdae882.tar.bz2 textual-inversion-diff-950f1f6bcbb1a767170cea590b828d8e3cdae882.zip | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 32 |
1 files changed, 14 insertions, 18 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index d284346..c8f03ea 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -145,12 +145,6 @@ def parse_args(): | |||
| 145 | help="Tokens to create an alias for.", | 145 | help="Tokens to create an alias for.", |
| 146 | ) | 146 | ) |
| 147 | parser.add_argument( | 147 | parser.add_argument( |
| 148 | "--inverted_initializer_tokens", | ||
| 149 | type=str, | ||
| 150 | nargs="*", | ||
| 151 | help="A token to use as initializer word.", | ||
| 152 | ) | ||
| 153 | parser.add_argument( | ||
| 154 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | 148 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
| 155 | ) | 149 | ) |
| 156 | parser.add_argument( | 150 | parser.add_argument( |
| @@ -499,6 +493,15 @@ def parse_args(): | |||
| 499 | help="Embedding dropout probability.", | 493 | help="Embedding dropout probability.", |
| 500 | ) | 494 | ) |
| 501 | parser.add_argument( | 495 | parser.add_argument( |
| 496 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." | ||
| 497 | ) | ||
| 498 | parser.add_argument( | ||
| 499 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." | ||
| 500 | ) | ||
| 501 | parser.add_argument( | ||
| 502 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." | ||
| 503 | ) | ||
| 504 | parser.add_argument( | ||
| 502 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." | 505 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." |
| 503 | ) | 506 | ) |
| 504 | parser.add_argument( | 507 | parser.add_argument( |
| @@ -554,18 +557,6 @@ def parse_args(): | |||
| 554 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | 557 | "--placeholder_tokens and --initializer_tokens must have the same number of items" |
| 555 | ) | 558 | ) |
| 556 | 559 | ||
| 557 | if isinstance(args.inverted_initializer_tokens, str): | ||
| 558 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
| 559 | args.placeholder_tokens | ||
| 560 | ) | ||
| 561 | |||
| 562 | if ( | ||
| 563 | isinstance(args.inverted_initializer_tokens, list) | ||
| 564 | and len(args.inverted_initializer_tokens) != 0 | ||
| 565 | ): | ||
| 566 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
| 567 | args.initializer_tokens += args.inverted_initializer_tokens | ||
| 568 | |||
| 569 | if isinstance(args.num_vectors, int): | 560 | if isinstance(args.num_vectors, int): |
| 570 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 561 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 571 | 562 | ||
| @@ -875,6 +866,11 @@ def main(): | |||
| 875 | sample_num_batches=args.sample_batches, | 866 | sample_num_batches=args.sample_batches, |
| 876 | sample_num_steps=args.sample_steps, | 867 | sample_num_steps=args.sample_steps, |
| 877 | sample_image_size=args.sample_image_size, | 868 | sample_image_size=args.sample_image_size, |
| 869 | placeholder_tokens=placeholder_tokens, | ||
| 870 | placeholder_token_ids=placeholder_token_ids, | ||
| 871 | use_emb_decay=args.use_emb_decay, | ||
| 872 | emb_decay_target=args.emb_decay_target, | ||
| 873 | emb_decay=args.emb_decay, | ||
| 878 | max_grad_norm=args.max_grad_norm, | 874 | max_grad_norm=args.max_grad_norm, |
| 879 | ) | 875 | ) |
| 880 | 876 | ||
