diff options
author | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
commit | b31fcb741432076f7e2f3ec9423ad935a08c6671 (patch) | |
tree | 2ab052d3bd617a56c4ea388c200da52cff39ba37 /train_ti.py | |
parent | Fix for latest PEFT (diff) | |
download | textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.gz textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.bz2 textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.zip |
Support LoRA training for token embeddings
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 10 |
1 file changed, 4 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 6fd974e..f60e3e5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -21,13 +21,14 @@ import transformers | |||
21 | import numpy as np | 21 | import numpy as np |
22 | from slugify import slugify | 22 | from slugify import slugify |
23 | 23 | ||
24 | from util.files import load_config, load_embeddings_from_dir | ||
25 | from data.csv import VlpnDataModule, keyword_filter | 24 | from data.csv import VlpnDataModule, keyword_filter |
25 | from models.clip.embeddings import patch_managed_embeddings | ||
26 | from training.functional import train, add_placeholder_tokens, get_models | 26 | from training.functional import train, add_placeholder_tokens, get_models |
27 | from training.strategy.ti import textual_inversion_strategy | 27 | from training.strategy.ti import textual_inversion_strategy |
28 | from training.optimization import get_scheduler | 28 | from training.optimization import get_scheduler |
29 | from training.sampler import create_named_schedule_sampler | 29 | from training.sampler import create_named_schedule_sampler |
30 | from training.util import AverageMeter, save_args | 30 | from training.util import AverageMeter, save_args |
31 | from util.files import load_config, load_embeddings_from_dir | ||
31 | 32 | ||
32 | logger = get_logger(__name__) | 33 | logger = get_logger(__name__) |
33 | 34 | ||
@@ -702,11 +703,8 @@ def main(): | |||
702 | 703 | ||
703 | save_args(output_dir, args) | 704 | save_args(output_dir, args) |
704 | 705 | ||
705 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 706 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) |
706 | args.pretrained_model_name_or_path, | 707 | embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
707 | args.emb_alpha, | ||
708 | args.emb_dropout | ||
709 | ) | ||
710 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 708 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) |
711 | 709 | ||
712 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 710 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |