diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 32 |
1 files changed, 14 insertions, 18 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index d284346..c8f03ea 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -145,12 +145,6 @@ def parse_args(): | |||
145 | help="Tokens to create an alias for.", | 145 | help="Tokens to create an alias for.", |
146 | ) | 146 | ) |
147 | parser.add_argument( | 147 | parser.add_argument( |
148 | "--inverted_initializer_tokens", | ||
149 | type=str, | ||
150 | nargs="*", | ||
151 | help="A token to use as initializer word.", | ||
152 | ) | ||
153 | parser.add_argument( | ||
154 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | 148 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
155 | ) | 149 | ) |
156 | parser.add_argument( | 150 | parser.add_argument( |
@@ -499,6 +493,15 @@ def parse_args(): | |||
499 | help="Embedding dropout probability.", | 493 | help="Embedding dropout probability.", |
500 | ) | 494 | ) |
501 | parser.add_argument( | 495 | parser.add_argument( |
496 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." | ||
497 | ) | ||
498 | parser.add_argument( | ||
499 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." | ||
500 | ) | ||
501 | parser.add_argument( | ||
502 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." | ||
503 | ) | ||
504 | parser.add_argument( | ||
502 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." | 505 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." |
503 | ) | 506 | ) |
504 | parser.add_argument( | 507 | parser.add_argument( |
@@ -554,18 +557,6 @@ def parse_args(): | |||
554 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | 557 | "--placeholder_tokens and --initializer_tokens must have the same number of items" |
555 | ) | 558 | ) |
556 | 559 | ||
557 | if isinstance(args.inverted_initializer_tokens, str): | ||
558 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
559 | args.placeholder_tokens | ||
560 | ) | ||
561 | |||
562 | if ( | ||
563 | isinstance(args.inverted_initializer_tokens, list) | ||
564 | and len(args.inverted_initializer_tokens) != 0 | ||
565 | ): | ||
566 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
567 | args.initializer_tokens += args.inverted_initializer_tokens | ||
568 | |||
569 | if isinstance(args.num_vectors, int): | 560 | if isinstance(args.num_vectors, int): |
570 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 561 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
571 | 562 | ||
@@ -875,6 +866,11 @@ def main(): | |||
875 | sample_num_batches=args.sample_batches, | 866 | sample_num_batches=args.sample_batches, |
876 | sample_num_steps=args.sample_steps, | 867 | sample_num_steps=args.sample_steps, |
877 | sample_image_size=args.sample_image_size, | 868 | sample_image_size=args.sample_image_size, |
869 | placeholder_tokens=placeholder_tokens, | ||
870 | placeholder_token_ids=placeholder_token_ids, | ||
871 | use_emb_decay=args.use_emb_decay, | ||
872 | emb_decay_target=args.emb_decay_target, | ||
873 | emb_decay=args.emb_decay, | ||
878 | max_grad_norm=args.max_grad_norm, | 874 | max_grad_norm=args.max_grad_norm, |
879 | ) | 875 | ) |
880 | 876 | ||