author    Volpeon <git@volpeon.ink>  2023-05-16 07:12:14 +0200
committer Volpeon <git@volpeon.ink>  2023-05-16 07:12:14 +0200
commit    b31fcb741432076f7e2f3ec9423ad935a08c6671 (patch)
tree      2ab052d3bd617a56c4ea388c200da52cff39ba37 /train_ti.py
parent    Fix for latest PEFT (diff)
download  textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.gz
          textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.bz2
          textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.zip
Support LoRA training for token embeddings
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 6fd974e..f60e3e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,13 +21,14 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 logger = get_logger(__name__)
 
@@ -702,11 +703,8 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
+    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
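
For reference, the refactored call site now loads the models and patches the
token embeddings in two explicit steps. A minimal sketch, assuming only the
signatures visible in this diff; the checkpoint name and the alpha/dropout
values are placeholders, not values taken from the repository:

    # Imports as arranged by this commit
    from models.clip.embeddings import patch_managed_embeddings
    from training.functional import get_models

    # get_models() is now only responsible for loading the pipeline components
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
        "runwayml/stable-diffusion-v1-5"  # placeholder for args.pretrained_model_name_or_path
    )

    # The managed token-embedding layer is attached in a separate step, mirroring
    # patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
    embeddings = patch_managed_embeddings(text_encoder, 8, 0.1)  # placeholder alpha/dropout values

Splitting the embedding patch out of get_models() keeps model loading generic
and makes the wrapping of the token embeddings an explicit step, in line with
the commit's goal of supporting LoRA training for token embeddings.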