From 30b557c8e1f03b4748ac3efca599ff51d66561cb Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 4 Apr 2023 07:30:43 +0200
Subject: TI: Bring back old embedding decay

---
 train_ti.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index a9a2333..4366c9e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,10 +451,21 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_alpha",
-        default=1.0,
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
         type=float,
-        help="Embedding alpha."
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -600,7 +611,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_alpha)
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -744,6 +755,9 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
-- 
cgit v1.2.3-54-g00ecf