From 555912a86b012382a78f1b2717c2e0fde5994a04 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 11:50:16 +0100 Subject: Make embedding decay work like Adam decay --- train_ti.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 0891c49..fc34d27 100644 --- a/train_ti.py +++ b/train_ti.py @@ -159,7 +159,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0, + default=0.1, help="Tag dropout probability.", ) parser.add_argument( @@ -406,17 +406,11 @@ def parse_args(): help="Embedding decay target." ) parser.add_argument( - "--emb_decay_factor", - default=1, + "--emb_decay", + default=1e-1, type=float, help="Embedding decay factor." ) - parser.add_argument( - "--emb_decay_start", - default=0, - type=float, - help="Embedding decay start offset." - ) parser.add_argument( "--noise_timesteps", type=int, @@ -587,12 +581,10 @@ def main(): tokenizer=tokenizer, sample_scheduler=sample_scheduler, checkpoint_output_dir=checkpoint_output_dir, - learning_rate=args.learning_rate, gradient_checkpointing=args.gradient_checkpointing, use_emb_decay=args.use_emb_decay, emb_decay_target=args.emb_decay_target, - emb_decay_factor=args.emb_decay_factor, - emb_decay_start=args.emb_decay_start, + emb_decay=args.emb_decay, use_ema=args.use_ema, ema_inv_gamma=args.ema_inv_gamma, ema_power=args.ema_power, -- cgit v1.2.3-54-g00ecf