diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 37 |
1 file changed, 11 insertions, 26 deletions
diff --git a/train_ti.py b/train_ti.py index 8dde1ba..0ad7574 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -353,7 +353,7 @@ def parse_args(): | |||
353 | parser.add_argument( | 353 | parser.add_argument( |
354 | "--adam_weight_decay", | 354 | "--adam_weight_decay", |
355 | type=float, | 355 | type=float, |
356 | default=0, | 356 | default=1e-2, |
357 | help="Weight decay to use." | 357 | help="Weight decay to use." |
358 | ) | 358 | ) |
359 | parser.add_argument( | 359 | parser.add_argument( |
@@ -451,21 +451,10 @@ def parse_args(): | |||
451 | help="The weight of prior preservation loss." | 451 | help="The weight of prior preservation loss." |
452 | ) | 452 | ) |
453 | parser.add_argument( | 453 | parser.add_argument( |
454 | "--use_emb_decay", | 454 | "--emb_alpha", |
455 | action="store_true", | 455 | default=1.0, |
456 | help="Whether to use embedding decay." | ||
457 | ) | ||
458 | parser.add_argument( | ||
459 | "--emb_decay_target", | ||
460 | default=0.4, | ||
461 | type=float, | ||
462 | help="Embedding decay target." | ||
463 | ) | ||
464 | parser.add_argument( | ||
465 | "--emb_decay", | ||
466 | default=1e2, | ||
467 | type=float, | 456 | type=float, |
468 | help="Embedding decay factor." | 457 | help="Embedding alpha." |
469 | ) | 458 | ) |
470 | parser.add_argument( | 459 | parser.add_argument( |
471 | "--noise_timesteps", | 460 | "--noise_timesteps", |
@@ -567,16 +556,16 @@ def parse_args(): | |||
567 | raise ValueError("You must specify --output_dir") | 556 | raise ValueError("You must specify --output_dir") |
568 | 557 | ||
569 | if args.adam_beta1 is None: | 558 | if args.adam_beta1 is None: |
570 | if args.optimizer in ('adam', 'adam8bit'): | 559 | if args.optimizer == 'lion': |
571 | args.adam_beta1 = 0.9 | ||
572 | elif args.optimizer == 'lion': | ||
573 | args.adam_beta1 = 0.95 | 560 | args.adam_beta1 = 0.95 |
561 | else: | ||
562 | args.adam_beta1 = 0.9 | ||
574 | 563 | ||
575 | if args.adam_beta2 is None: | 564 | if args.adam_beta2 is None: |
576 | if args.optimizer in ('adam', 'adam8bit'): | 565 | if args.optimizer == 'lion': |
577 | args.adam_beta2 = 0.999 | ||
578 | elif args.optimizer == 'lion': | ||
579 | args.adam_beta2 = 0.98 | 566 | args.adam_beta2 = 0.98 |
567 | else: | ||
568 | args.adam_beta2 = 0.999 | ||
580 | 569 | ||
581 | return args | 570 | return args |
582 | 571 | ||
@@ -611,7 +600,7 @@ def main(): | |||
611 | save_args(output_dir, args) | 600 | save_args(output_dir, args) |
612 | 601 | ||
613 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 602 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
614 | args.pretrained_model_name_or_path) | 603 | args.pretrained_model_name_or_path, args.emb_alpha) |
615 | 604 | ||
616 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 605 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
617 | tokenizer.set_dropout(args.vector_dropout) | 606 | tokenizer.set_dropout(args.vector_dropout) |
@@ -755,10 +744,6 @@ def main(): | |||
755 | tokenizer=tokenizer, | 744 | tokenizer=tokenizer, |
756 | sample_scheduler=sample_scheduler, | 745 | sample_scheduler=sample_scheduler, |
757 | checkpoint_output_dir=checkpoint_output_dir, | 746 | checkpoint_output_dir=checkpoint_output_dir, |
758 | gradient_checkpointing=args.gradient_checkpointing, | ||
759 | use_emb_decay=args.use_emb_decay, | ||
760 | emb_decay_target=args.emb_decay_target, | ||
761 | emb_decay=args.emb_decay, | ||
762 | use_ema=args.use_ema, | 747 | use_ema=args.use_ema, |
763 | ema_inv_gamma=args.ema_inv_gamma, | 748 | ema_inv_gamma=args.ema_inv_gamma, |
764 | ema_power=args.ema_power, | 749 | ema_power=args.ema_power, |