From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 train_ti.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index d931db6..6c57f4b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -18,7 +18,6 @@ import transformers

 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -354,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="dadan",
+        default="adan",
         choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
@@ -379,7 +378,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=2e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -483,7 +482,19 @@
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_dropout",
+        "--lora_r",
+        type=int,
+        default=8,
+        help="Lora rank, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=32,
+        help="Lora alpha, only used if use_lora is True"
+    )
+    parser.add_argument(
+        "--lora_dropout",
         type=float,
         default=0,
         help="Embedding dropout probability.",
@@ -655,7 +666,11 @@ def main():
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.lora_r,
+        args.lora_alpha,
+        args.lora_dropout
+    )

     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -747,6 +762,7 @@ def main():
             timm.optim.Adan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            no_prox=True,
         )
     elif args.optimizer == 'lion':
         try:
@@ -914,7 +930,7 @@ def main():
         print("")

     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
         lr=learning_rate,
     )

--
cgit v1.2.3-54-g00ecf
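
For orientation: the patch swaps the plain embedding override for a LoRA-style parameterization of the token embedding, exposing --lora_r, --lora_alpha and --lora_dropout. The actual implementation lives behind get_models() and token_embedding and is not part of this patch; the sketch below is only an illustration of the usual LoRA formulation (effective weight W + (alpha / r) * B @ A) applied to an embedding table. The class name, shapes and init choices are assumptions, not the repository's code.

# Hypothetical sketch, not the code behind get_models()/token_embedding.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraEmbedding(nn.Module):
    """Token embedding with a trainable low-rank update on a frozen base:
    effective weight = W + (alpha / r) * (B @ A)."""

    def __init__(self, num_embeddings, embedding_dim, r=8, alpha=32, dropout=0.0):
        super().__init__()
        self.base = nn.Embedding(num_embeddings, embedding_dim)
        self.base.weight.requires_grad_(False)  # only the LoRA factors train
        # A is zero-initialized so the update starts at exactly zero,
        # i.e. training begins from the unmodified base embedding.
        self.lora_a = nn.Parameter(torch.zeros(r, num_embeddings))
        self.lora_b = nn.Parameter(torch.randn(embedding_dim, r) * 0.01)
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        # Row lookup in A^T stands in for one_hot(ids) @ A^T, shape (..., r)
        after_a = F.embedding(input_ids, self.lora_a.T)
        delta = self.dropout(after_a) @ self.lora_b.T  # (..., embedding_dim)
        return self.base(input_ids) + self.scaling * delta


# Example: with 49408 tokens / 768 dims (SD 1.x text encoder sizes),
# only the low-rank factors receive gradients:
emb = LoraEmbedding(49408, 768, r=8, alpha=32)
trainable = [n for n, p in emb.named_parameters() if p.requires_grad]
# -> ['lora_a', 'lora_b']

Under such a parameterization, handing token_embedding.parameters() to create_optimizer() (as the last hunk does) trains only the low-rank factors, since the base weight is frozen. Separately, no_prox=True selects timm Adan's decoupled, AdamW-style handling of weight decay rather than the proximal update, which becomes relevant now that --adam_weight_decay defaults to 2e-2 instead of 0.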