Diffstat (limited to 'train_ti.py')

 train_ti.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index d931db6..6c57f4b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -18,7 +18,6 @@ import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -354,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="dadan",
+        default="adan",
         choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
@@ -379,7 +378,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=2e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -483,7 +482,19 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_dropout",
+        "--lora_r",
+        type=int,
+        default=8,
+        help="Lora rank, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_dropout",
         type=float,
         default=0,
         help="Embedding dropout probability.",
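
The three new flags are the standard LoRA hyperparameters. For reference (this is a generic sketch, not code from this repository), a minimal LoRA layer shows where each one acts: `r` bounds the rank of the low-rank update, the update is scaled by `alpha / r` (32 / 8 = 4 with these defaults), and the dropout is applied to the adapter's input.

```python
# Generic LoRA sketch for reference -- not code from this repository.
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 32, dropout: float = 0.0):
        super().__init__()
        self.base = base                                        # frozen pretrained projection
        self.down = nn.Linear(base.in_features, r, bias=False)  # project down to rank r
        self.up = nn.Linear(r, base.out_features, bias=False)   # project back up
        nn.init.zeros_(self.up.weight)                          # adapter starts as a no-op
        self.dropout = nn.Dropout(dropout)
        self.scale = alpha / r                                  # 32 / 8 = 4 with these defaults

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(self.dropout(x)))
```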
@@ -655,7 +666,11 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.lora_r,
+        args.lora_alpha,
+        args.lora_dropout
+    )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
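
`get_models()` now receives the LoRA hyperparameters in place of the old embedding dropout. Its body lives in `training/functional.py` and is not part of this diff; the sketch below is only a guess at its shape, assuming a PEFT-style `LoraConfig` (an assumption, not the repository's confirmed implementation).

```python
# Speculative sketch of the updated signature; the real get_models() is in
# training/functional.py and may differ.
from peft import LoraConfig  # assumption: a PEFT-style LoRA config is used


def get_models(pretrained_model_name_or_path: str,
               lora_r: int = 8, lora_alpha: int = 32, lora_dropout: float = 0.0):
    lora_config = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
    ...  # load tokenizer/text_encoder/vae/unet/schedulers and attach the adapter
```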
@@ -747,6 +762,7 @@ def main():
             timm.optim.Adan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            no_prox=True,
         )
     elif args.optimizer == 'lion':
         try:
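
`no_prox=True` selects the non-proximal variant of Adan in timm's implementation, i.e. AdamW-style decoupled weight decay rather than a proximal step, which now matters because the default weight decay is nonzero. Judging by the call sites, `create_optimizer` appears to be a `functools.partial` over the optimizer class; here is a sketch with the new defaults filled in (the `partial` usage and the epsilon value are inferred, not shown in this diff).

```python
from functools import partial

import timm.optim

# Sketch only: bind the optimizer hyperparameters, leaving params and lr open.
create_optimizer = partial(
    timm.optim.Adan,
    weight_decay=2e-2,  # the new --adam_weight_decay default
    eps=1e-8,
    no_prox=True,       # decoupled (AdamW-style) decay instead of the proximal step
)
# Later in main(): optimizer = create_optimizer(model_params, lr=learning_rate)
```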
@@ -914,7 +930,7 @@ def main():
     print("")
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
         lr=learning_rate,
     )
