diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index a9a2333..4366c9e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -353,7 +353,7 @@ def parse_args(): | |||
353 | parser.add_argument( | 353 | parser.add_argument( |
354 | "--adam_weight_decay", | 354 | "--adam_weight_decay", |
355 | type=float, | 355 | type=float, |
356 | default=1e-2, | 356 | default=0, |
357 | help="Weight decay to use." | 357 | help="Weight decay to use." |
358 | ) | 358 | ) |
359 | parser.add_argument( | 359 | parser.add_argument( |
@@ -451,10 +451,21 @@ def parse_args(): | |||
451 | help="The weight of prior preservation loss." | 451 | help="The weight of prior preservation loss." |
452 | ) | 452 | ) |
453 | parser.add_argument( | 453 | parser.add_argument( |
454 | "--emb_alpha", | 454 | "--use_emb_decay", |
455 | default=1.0, | 455 | action="store_true", |
456 | help="Whether to use embedding decay." | ||
457 | ) | ||
458 | parser.add_argument( | ||
459 | "--emb_decay_target", | ||
460 | default=0.4, | ||
461 | type=float, | ||
462 | help="Embedding decay target." | ||
463 | ) | ||
464 | parser.add_argument( | ||
465 | "--emb_decay", | ||
466 | default=1e+2, | ||
456 | type=float, | 467 | type=float, |
457 | help="Embedding alpha." | 468 | help="Embedding decay factor." |
458 | ) | 469 | ) |
459 | parser.add_argument( | 470 | parser.add_argument( |
460 | "--noise_timesteps", | 471 | "--noise_timesteps", |
@@ -600,7 +611,7 @@ def main(): | |||
600 | save_args(output_dir, args) | 611 | save_args(output_dir, args) |
601 | 612 | ||
602 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 613 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
603 | args.pretrained_model_name_or_path, args.emb_alpha) | 614 | args.pretrained_model_name_or_path) |
604 | 615 | ||
605 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 616 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
606 | tokenizer.set_dropout(args.vector_dropout) | 617 | tokenizer.set_dropout(args.vector_dropout) |
@@ -744,6 +755,9 @@ def main(): | |||
744 | tokenizer=tokenizer, | 755 | tokenizer=tokenizer, |
745 | sample_scheduler=sample_scheduler, | 756 | sample_scheduler=sample_scheduler, |
746 | checkpoint_output_dir=checkpoint_output_dir, | 757 | checkpoint_output_dir=checkpoint_output_dir, |
758 | use_emb_decay=args.use_emb_decay, | ||
759 | emb_decay_target=args.emb_decay_target, | ||
760 | emb_decay=args.emb_decay, | ||
747 | use_ema=args.use_ema, | 761 | use_ema=args.use_ema, |
748 | ema_inv_gamma=args.ema_inv_gamma, | 762 | ema_inv_gamma=args.ema_inv_gamma, |
749 | ema_power=args.ema_power, | 763 | ema_power=args.ema_power, |