From b31fcb741432076f7e2f3ec9423ad935a08c6671 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 07:12:14 +0200
Subject: Support LoRA training for token embeddings

---
 train_ti.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 6fd974e..f60e3e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,13 +21,14 @@ import transformers
 import numpy as np
 from slugify import slugify
 
-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir
 
 logger = get_logger(__name__)
 
@@ -702,11 +703,8 @@ def main():
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
+    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
 
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
--
cgit v1.2.3-70-g09d2
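
Note: the body of patch_managed_embeddings lives in models/clip/embeddings.py and is not part of this diff; the patch only moves the embedding setup out of get_models into an explicit call so the token embeddings can be trained LoRA-style. For readers following along, below is a minimal, hypothetical Python sketch of what a LoRA wrapper of this shape could look like. The names LoraEmbedding and patch_token_embeddings, the rank r, and the initialization scheme are illustrative assumptions, not the project's actual implementation; only the alpha and dropout knobs mirror the emb_alpha/emb_dropout arguments visible above, and the token_embedding attribute path is the one transformers exposes on a CLIP text encoder.

import torch
import torch.nn as nn


class LoraEmbedding(nn.Module):
    # Hypothetical LoRA wrapper around a frozen token embedding:
    # the base weight stays fixed, and a trainable low-rank update
    # (lora_a @ lora_b, scaled by alpha / r) is added on top of it.
    def __init__(self, base: nn.Embedding, r: int = 8,
                 alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the original embedding
        self.scale = alpha / r
        self.dropout = nn.Dropout(dropout)
        # lora_b starts at zero, so the initial delta is exactly zero and
        # training begins from the unmodified pretrained embedding.
        self.lora_a = nn.Parameter(torch.randn(base.num_embeddings, r) / r)
        self.lora_b = nn.Parameter(torch.zeros(r, base.embedding_dim))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Frozen lookup plus the trainable low-rank delta for the same ids.
        delta = self.dropout(self.lora_a[input_ids]) @ self.lora_b
        return self.base(input_ids) + delta * self.scale


def patch_token_embeddings(text_encoder, alpha: float, dropout: float) -> LoraEmbedding:
    # Illustrative stand-in for the repo's patch_managed_embeddings: swap the
    # CLIP text encoder's token embedding for the LoRA wrapper in place.
    lora = LoraEmbedding(text_encoder.text_model.embeddings.token_embedding,
                         alpha=alpha, dropout=dropout)
    text_encoder.text_model.embeddings.token_embedding = lora
    return lora

With a wrapper along these lines, only lora_a and lora_b (plus any placeholder token vectors) need to be handed to the optimizer, which is what makes LoRA training of the embeddings cheap compared to fine-tuning the full embedding matrix.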