diff options
| author | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
| commit | b31fcb741432076f7e2f3ec9423ad935a08c6671 (patch) | |
| tree | 2ab052d3bd617a56c4ea388c200da52cff39ba37 /train_ti.py | |
| parent | Fix for latest PEFT (diff) | |
| download | textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.gz textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.bz2 textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.zip | |
Support LoRA training for token embeddings
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 10 |
1 file changed, 4 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 6fd974e..f60e3e5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -21,13 +21,14 @@ import transformers | |||
| 21 | import numpy as np | 21 | import numpy as np |
| 22 | from slugify import slugify | 22 | from slugify import slugify |
| 23 | 23 | ||
| 24 | from util.files import load_config, load_embeddings_from_dir | ||
| 25 | from data.csv import VlpnDataModule, keyword_filter | 24 | from data.csv import VlpnDataModule, keyword_filter |
| 25 | from models.clip.embeddings import patch_managed_embeddings | ||
| 26 | from training.functional import train, add_placeholder_tokens, get_models | 26 | from training.functional import train, add_placeholder_tokens, get_models |
| 27 | from training.strategy.ti import textual_inversion_strategy | 27 | from training.strategy.ti import textual_inversion_strategy |
| 28 | from training.optimization import get_scheduler | 28 | from training.optimization import get_scheduler |
| 29 | from training.sampler import create_named_schedule_sampler | 29 | from training.sampler import create_named_schedule_sampler |
| 30 | from training.util import AverageMeter, save_args | 30 | from training.util import AverageMeter, save_args |
| 31 | from util.files import load_config, load_embeddings_from_dir | ||
| 31 | 32 | ||
| 32 | logger = get_logger(__name__) | 33 | logger = get_logger(__name__) |
| 33 | 34 | ||
| @@ -702,11 +703,8 @@ def main(): | |||
| 702 | 703 | ||
| 703 | save_args(output_dir, args) | 704 | save_args(output_dir, args) |
| 704 | 705 | ||
| 705 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 706 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) |
| 706 | args.pretrained_model_name_or_path, | 707 | embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
| 707 | args.emb_alpha, | ||
| 708 | args.emb_dropout | ||
| 709 | ) | ||
| 710 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 708 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) |
| 711 | 709 | ||
| 712 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 710 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
