diff options
author | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200 |
commit | 99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch) | |
tree | 717a4099e9ebfedec702060fed5ed12aaceb0094 /train_ti.py | |
parent | Added cycle LR decay (diff) | |
download | textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.gz textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.bz2 textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.zip |
TI via LoRA
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 28 |
1 file changed, 22 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index d931db6..6c57f4b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -18,7 +18,6 @@ import transformers | |||
18 | 18 | ||
19 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
20 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
21 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
22 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
23 | from training.strategy.ti import textual_inversion_strategy | 22 | from training.strategy.ti import textual_inversion_strategy |
24 | from training.optimization import get_scheduler | 23 | from training.optimization import get_scheduler |
@@ -354,7 +353,7 @@ def parse_args(): | |||
354 | parser.add_argument( | 353 | parser.add_argument( |
355 | "--optimizer", | 354 | "--optimizer", |
356 | type=str, | 355 | type=str, |
357 | default="dadan", | 356 | default="adan", |
358 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 357 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
359 | help='Optimizer to use' | 358 | help='Optimizer to use' |
360 | ) | 359 | ) |
@@ -379,7 +378,7 @@ def parse_args(): | |||
379 | parser.add_argument( | 378 | parser.add_argument( |
380 | "--adam_weight_decay", | 379 | "--adam_weight_decay", |
381 | type=float, | 380 | type=float, |
382 | default=0, | 381 | default=2e-2, |
383 | help="Weight decay to use." | 382 | help="Weight decay to use." |
384 | ) | 383 | ) |
385 | parser.add_argument( | 384 | parser.add_argument( |
@@ -483,7 +482,19 @@ def parse_args(): | |||
483 | help="The weight of prior preservation loss." | 482 | help="The weight of prior preservation loss." |
484 | ) | 483 | ) |
485 | parser.add_argument( | 484 | parser.add_argument( |
486 | "--emb_dropout", | 485 | "--lora_r", |
486 | type=int, | ||
487 | default=8, | ||
488 | help="Lora rank, only used if use_lora is True" | ||
489 | ) | ||
490 | parser.add_argument( | ||
491 | "--lora_alpha", | ||
492 | type=int, | ||
493 | default=32, | ||
494 | help="Lora alpha, only used if use_lora is True" | ||
495 | ) | ||
496 | parser.add_argument( | ||
497 | "--lora_dropout", | ||
487 | type=float, | 498 | type=float, |
488 | default=0, | 499 | default=0, |
489 | help="Embedding dropout probability.", | 500 | help="Embedding dropout probability.", |
@@ -655,7 +666,11 @@ def main(): | |||
655 | save_args(output_dir, args) | 666 | save_args(output_dir, args) |
656 | 667 | ||
657 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 668 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
658 | args.pretrained_model_name_or_path, args.emb_dropout) | 669 | args.pretrained_model_name_or_path, |
670 | args.lora_r, | ||
671 | args.lora_alpha, | ||
672 | args.lora_dropout | ||
673 | ) | ||
659 | 674 | ||
660 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 675 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
661 | tokenizer.set_dropout(args.vector_dropout) | 676 | tokenizer.set_dropout(args.vector_dropout) |
@@ -747,6 +762,7 @@ def main(): | |||
747 | timm.optim.Adan, | 762 | timm.optim.Adan, |
748 | weight_decay=args.adam_weight_decay, | 763 | weight_decay=args.adam_weight_decay, |
749 | eps=args.adam_epsilon, | 764 | eps=args.adam_epsilon, |
765 | no_prox=True, | ||
750 | ) | 766 | ) |
751 | elif args.optimizer == 'lion': | 767 | elif args.optimizer == 'lion': |
752 | try: | 768 | try: |
@@ -914,7 +930,7 @@ def main(): | |||
914 | print("") | 930 | print("") |
915 | 931 | ||
916 | optimizer = create_optimizer( | 932 | optimizer = create_optimizer( |
917 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 933 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
918 | lr=learning_rate, | 934 | lr=learning_rate, |
919 | ) | 935 | ) |
920 | 936 | ||