diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 16 |
1 file changed, 4 insertions, 12 deletions
diff --git a/train_ti.py b/train_ti.py index 0891c49..fc34d27 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -159,7 +159,7 @@ def parse_args(): | |||
159 | parser.add_argument( | 159 | parser.add_argument( |
160 | "--tag_dropout", | 160 | "--tag_dropout", |
161 | type=float, | 161 | type=float, |
162 | default=0, | 162 | default=0.1, |
163 | help="Tag dropout probability.", | 163 | help="Tag dropout probability.", |
164 | ) | 164 | ) |
165 | parser.add_argument( | 165 | parser.add_argument( |
@@ -406,18 +406,12 @@ def parse_args(): | |||
406 | help="Embedding decay target." | 406 | help="Embedding decay target." |
407 | ) | 407 | ) |
408 | parser.add_argument( | 408 | parser.add_argument( |
409 | "--emb_decay_factor", | 409 | "--emb_decay", |
410 | default=1, | 410 | default=1e-1, |
411 | type=float, | 411 | type=float, |
412 | help="Embedding decay factor." | 412 | help="Embedding decay factor." |
413 | ) | 413 | ) |
414 | parser.add_argument( | 414 | parser.add_argument( |
415 | "--emb_decay_start", | ||
416 | default=0, | ||
417 | type=float, | ||
418 | help="Embedding decay start offset." | ||
419 | ) | ||
420 | parser.add_argument( | ||
421 | "--noise_timesteps", | 415 | "--noise_timesteps", |
422 | type=int, | 416 | type=int, |
423 | default=1000, | 417 | default=1000, |
@@ -587,12 +581,10 @@ def main(): | |||
587 | tokenizer=tokenizer, | 581 | tokenizer=tokenizer, |
588 | sample_scheduler=sample_scheduler, | 582 | sample_scheduler=sample_scheduler, |
589 | checkpoint_output_dir=checkpoint_output_dir, | 583 | checkpoint_output_dir=checkpoint_output_dir, |
590 | learning_rate=args.learning_rate, | ||
591 | gradient_checkpointing=args.gradient_checkpointing, | 584 | gradient_checkpointing=args.gradient_checkpointing, |
592 | use_emb_decay=args.use_emb_decay, | 585 | use_emb_decay=args.use_emb_decay, |
593 | emb_decay_target=args.emb_decay_target, | 586 | emb_decay_target=args.emb_decay_target, |
594 | emb_decay_factor=args.emb_decay_factor, | 587 | emb_decay=args.emb_decay, |
595 | emb_decay_start=args.emb_decay_start, | ||
596 | use_ema=args.use_ema, | 588 | use_ema=args.use_ema, |
597 | ema_inv_gamma=args.ema_inv_gamma, | 589 | ema_inv_gamma=args.ema_inv_gamma, |
598 | ema_power=args.ema_power, | 590 | ema_power=args.ema_power, |