diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 file changed, 2 insertions(+), 22 deletions(-)
diff --git a/train_ti.py b/train_ti.py index 5482326..0ce0056 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -353,7 +353,7 @@ def parse_args(): | |||
353 | parser.add_argument( | 353 | parser.add_argument( |
354 | "--adam_weight_decay", | 354 | "--adam_weight_decay", |
355 | type=float, | 355 | type=float, |
356 | default=0, | 356 | default=1e-2, |
357 | help="Weight decay to use." | 357 | help="Weight decay to use." |
358 | ) | 358 | ) |
359 | parser.add_argument( | 359 | parser.add_argument( |
@@ -451,23 +451,6 @@ def parse_args(): | |||
451 | help="The weight of prior preservation loss." | 451 | help="The weight of prior preservation loss." |
452 | ) | 452 | ) |
453 | parser.add_argument( | 453 | parser.add_argument( |
454 | "--use_emb_decay", | ||
455 | action="store_true", | ||
456 | help="Whether to use embedding decay." | ||
457 | ) | ||
458 | parser.add_argument( | ||
459 | "--emb_decay_target", | ||
460 | default=0.4, | ||
461 | type=float, | ||
462 | help="Embedding decay target." | ||
463 | ) | ||
464 | parser.add_argument( | ||
465 | "--emb_decay", | ||
466 | default=1e2, | ||
467 | type=float, | ||
468 | help="Embedding decay factor." | ||
469 | ) | ||
470 | parser.add_argument( | ||
471 | "--noise_timesteps", | 454 | "--noise_timesteps", |
472 | type=int, | 455 | type=int, |
473 | default=1000, | 456 | default=1000, |
@@ -732,9 +715,6 @@ def main(): | |||
732 | sample_scheduler=sample_scheduler, | 715 | sample_scheduler=sample_scheduler, |
733 | checkpoint_output_dir=checkpoint_output_dir, | 716 | checkpoint_output_dir=checkpoint_output_dir, |
734 | gradient_checkpointing=args.gradient_checkpointing, | 717 | gradient_checkpointing=args.gradient_checkpointing, |
735 | use_emb_decay=args.use_emb_decay, | ||
736 | emb_decay_target=args.emb_decay_target, | ||
737 | emb_decay=args.emb_decay, | ||
738 | use_ema=args.use_ema, | 718 | use_ema=args.use_ema, |
739 | ema_inv_gamma=args.ema_inv_gamma, | 719 | ema_inv_gamma=args.ema_inv_gamma, |
740 | ema_power=args.ema_power, | 720 | ema_power=args.ema_power, |
@@ -800,7 +780,7 @@ def main(): | |||
800 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 780 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
801 | 781 | ||
802 | optimizer = create_optimizer( | 782 | optimizer = create_optimizer( |
803 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 783 | text_encoder.text_model.embeddings.overlay.parameters(), |
804 | lr=args.learning_rate, | 784 | lr=args.learning_rate, |
805 | ) | 785 | ) |
806 | 786 | ||