From 950f1f6bcbb1a767170cea590b828d8e3cdae882 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Jun 2023 06:48:38 +0200 Subject: Update --- train_dreambooth.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index d284346..c8f03ea 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -144,12 +144,6 @@ def parse_args(): default=[], help="Tokens to create an alias for.", ) - parser.add_argument( - "--inverted_initializer_tokens", - type=str, - nargs="*", - help="A token to use as initializer word.", - ) parser.add_argument( "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." ) @@ -498,6 +492,15 @@ def parse_args(): default=0, help="Embedding dropout probability.", ) + parser.add_argument( + "--use_emb_decay", action="store_true", help="Whether to use embedding decay." + ) + parser.add_argument( + "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." + ) + parser.add_argument( + "--emb_decay", default=1e2, type=float, help="Embedding decay factor." + ) parser.add_argument( "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) @@ -554,18 +557,6 @@ def parse_args(): "--placeholder_tokens and --initializer_tokens must have the same number of items" ) - if isinstance(args.inverted_initializer_tokens, str): - args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( - args.placeholder_tokens - ) - - if ( - isinstance(args.inverted_initializer_tokens, list) - and len(args.inverted_initializer_tokens) != 0 - ): - args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] - args.initializer_tokens += args.inverted_initializer_tokens - if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) @@ -875,6 +866,11 @@ def main(): sample_num_batches=args.sample_batches, sample_num_steps=args.sample_steps, sample_image_size=args.sample_image_size, + placeholder_tokens=placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay=args.emb_decay, max_grad_norm=args.max_grad_norm, ) -- cgit v1.2.3-54-g00ecf