author    Volpeon <git@volpeon.ink>    2023-04-15 13:11:11 +0200
committer Volpeon <git@volpeon.ink>    2023-04-15 13:11:11 +0200
commit    99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch)
tree      717a4099e9ebfedec702060fed5ed12aaceb0094 /train_ti.py
parent    Added cycle LR decay (diff)
TI via LoRA
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  28
1 file changed, 22 insertions(+), 6 deletions(-)
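In short: the commit trains the textual-inversion embedding through a LoRA-style adapter on the text encoder's token embedding instead of the previous token-override embedding. It drops the ConvNeXtDiscriminator import, changes the optimizer default from "dadan" to "adan" (weight decay default 2e-2, no_prox=True for timm's Adan), replaces --emb_dropout with --lora_r, --lora_alpha and --lora_dropout, passes those values to get_models(), and builds the optimizer over text_encoder.text_model.embeddings.token_embedding.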
diff --git a/train_ti.py b/train_ti.py
index d931db6..6c57f4b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -18,7 +18,6 @@ import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -354,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="dadan",
+        default="adan",
         choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
@@ -379,7 +378,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=2e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -483,7 +482,19 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_dropout",
+        "--lora_r",
+        type=int,
+        default=8,
+        help="Lora rank, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_dropout",
         type=float,
         default=0,
         help="Embedding dropout probability.",
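The new --lora_r and --lora_alpha defaults (8 and 32) follow the common LoRA convention in which the rank-r update is scaled by alpha / r before being added to the frozen base weight. Below is a minimal, self-contained sketch of a LoRA-wrapped token embedding under that convention; the class and its wiring are illustrative assumptions, not this repository's implementation.

import torch
import torch.nn as nn

class LoraEmbedding(nn.Module):
    # Illustrative only: frozen base embedding plus a rank-r update scaled by alpha / r.
    def __init__(self, base: nn.Embedding, r: int = 8, alpha: int = 32, dropout: float = 0.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)            # base vectors stay frozen
        self.lora_a = nn.Parameter(torch.randn(base.num_embeddings, r) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(r, base.embedding_dim))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # low-rank delta A @ B, gathered per token and scaled by alpha / r
        delta = (self.lora_a @ self.lora_b)[input_ids]
        return self.base(input_ids) + self.dropout(delta) * self.scaling

emb = LoraEmbedding(nn.Embedding(1000, 64), r=8, alpha=32, dropout=0.0)
out = emb(torch.tensor([[1, 2, 3]]))   # shape (1, 3, 64); equals the base lookup at init since lora_b starts at zero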
@@ -655,7 +666,11 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.lora_r,
+        args.lora_alpha,
+        args.lora_dropout
+    )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -747,6 +762,7 @@ def main():
             timm.optim.Adan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            no_prox=True,
         )
     elif args.optimizer == 'lion':
         try:
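timm's Adan implementation exposes a no_prox flag that selects between the two weight-decay update rules described in the Adan paper; this commit enables the non-proximal form alongside the new 2e-2 weight-decay default. A hedged, self-contained construction example (the model and hyperparameter values are placeholders, not the training script's actual values):

from functools import partial

import torch.nn as nn
import timm.optim

model = nn.Linear(16, 16)   # placeholder module

create_optimizer = partial(
    timm.optim.Adan,
    weight_decay=2e-2,       # new --adam_weight_decay default
    eps=1e-8,
    no_prox=True,            # non-proximal weight-decay update, as enabled in this commit
)

optimizer = create_optimizer(model.parameters(), lr=1e-3)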
@@ -914,7 +930,7 @@ def main():
     print("")
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
         lr=learning_rate,
     )
 
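With the override embedding gone, the optimizer is built directly over the token embedding's parameters, so if get_models wraps that layer in LoRA, only the low-rank factors and any newly added token vectors receive updates. A small stand-in illustration of the pattern of optimizing just one submodule while the rest of the text model stays frozen (all names here are placeholders, not the script's objects):

import torch.nn as nn
import timm.optim

class TinyTextModel(nn.Module):
    # Stand-in for text_encoder.text_model: an embedding plus other, frozen layers.
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(1000, 64)
        self.encoder = nn.Linear(64, 64)

model = TinyTextModel()
model.requires_grad_(False)
model.token_embedding.requires_grad_(True)    # train only the (LoRA-wrapped) embedding

optimizer = timm.optim.Adan(
    model.token_embedding.parameters(),        # only the embedding layer is optimized, as in the diff
    lr=1e-3,
    weight_decay=2e-2,
    no_prox=True,
)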