summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
 train_dreambooth.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d284346..c8f03ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -145,12 +145,6 @@ def parse_args():
         help="Tokens to create an alias for.",
     )
     parser.add_argument(
-        "--inverted_initializer_tokens",
-        type=str,
-        nargs="*",
-        help="A token to use as initializer word.",
-    )
-    parser.add_argument(
         "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
@@ -499,6 +493,15 @@ def parse_args():
         help="Embedding dropout probability.",
     )
     parser.add_argument(
+        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
     )
     parser.add_argument(
@@ -554,18 +557,6 @@ def parse_args():
             "--placeholder_tokens and --initializer_tokens must have the same number of items"
         )
 
-    if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
-            args.placeholder_tokens
-        )
-
-    if (
-        isinstance(args.inverted_initializer_tokens, list)
-        and len(args.inverted_initializer_tokens) != 0
-    ):
-        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
-        args.initializer_tokens += args.inverted_initializer_tokens
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
@@ -875,6 +866,11 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
+        placeholder_tokens=placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
     )
 