From b31fcb741432076f7e2f3ec9423ad935a08c6671 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 07:12:14 +0200
Subject: Support LoRA training for token embeddings

---
 models/clip/embeddings.py |  3 +++
 train_lora.py             | 32 ++++++++++++++++++++++++--------
 train_ti.py               | 10 ++++------
 training/functional.py    | 12 +++---------
 training/strategy/lora.py |  4 ++--
 5 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 2b23bd3..7c7f2ac 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -86,6 +86,9 @@ def patch_managed_embeddings(
     alpha: int = 8,
     dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
+    if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
+        return text_encoder.text_model.embeddings
+
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
diff --git a/train_lora.py b/train_lora.py
index dea58cf..167b17a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -22,17 +22,19 @@ import transformers
 import numpy as np
 from slugify import slugify

-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir

 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
 TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"]

 logger = get_logger(__name__)

@@ -44,9 +46,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True

 torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True

 hidet.torch.dynamo_config.use_tensor_core(True)
-hidet.torch.dynamo_config.use_attention(True)
 hidet.torch.dynamo_config.search_space(0)


@@ -321,6 +323,11 @@ def parse_args():
         default="none",
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
+    parser.add_argument(
+        "--lora_text_encoder_emb",
+        action="store_true",
+        help="Include token embeddings in training. Prevents usage of TI techniques.",
+    )
     parser.add_argument(
         "--train_text_encoder_cycles",
         default=999999,
@@ -717,12 +724,13 @@ def main():

     save_args(output_dir, args)

-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+
+    def ensure_embeddings():
+        if args.lora_text_encoder_emb:
+            raise ValueError("Can't use TI options when training token embeddings with LoRA")
+        return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)

     unet_config = LoraConfig(
         r=args.lora_r,
@@ -736,7 +744,7 @@ def main():
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
         lora_alpha=args.lora_text_encoder_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES,
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
@@ -765,6 +773,8 @@ def main():
         unet.enable_gradient_checkpointing()

     if len(args.alias_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]

@@ -781,6 +791,8 @@ def main():
     placeholder_token_ids = []

     if args.embeddings_dir is not None:
+        embeddings = ensure_embeddings()
+
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
@@ -798,6 +810,8 @@ def main():
             embeddings.persist()

     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
+        embeddings = ensure_embeddings()
+
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -997,6 +1011,8 @@ def main():
     # --------------------------------------------------------------------------------

     if args.run_pti and len(placeholder_tokens) != 0:
+        embeddings = ensure_embeddings()
+
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]

         pti_datamodule = create_datamodule(
diff --git a/train_ti.py b/train_ti.py
index 6fd974e..f60e3e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,13 +21,14 @@ import transformers
 import numpy as np
 from slugify import slugify

-from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.clip.embeddings import patch_managed_embeddings
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.sampler import create_named_schedule_sampler
 from training.util import AverageMeter, save_args
+from util.files import load_config, load_embeddings_from_dir

 logger = get_logger(__name__)

@@ -702,11 +703,8 @@ def main():

     save_args(output_dir, args)

-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path,
-        args.emb_alpha,
-        args.emb_dropout
-    )
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
+    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)

     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm

 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable


-def get_models(
-    pretrained_model_name_or_path: str,
-    emb_alpha: int = 8,
-    emb_dropout: float = 0.0
-):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')

-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler


 def save_samples(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
                 max_grad_norm
             )

-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
             params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(

     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
             lr = lrs["emb"]
             lambda_ = emb_decay * lr

--
cgit v1.2.3-70-g09d2
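
Usage note (not part of the patch): after this change, get_models() returns only the raw pipeline components and no longer patches the text encoder. Callers that need managed token embeddings call patch_managed_embeddings() themselves, as train_ti.py now does directly and train_lora.py does lazily through ensure_embeddings(), which refuses the TI options when --lora_text_encoder_emb puts the token embedding layer under LoRA instead. Below is a minimal sketch of the resulting caller-side wiring, assuming the repository modules are importable; the checkpoint id and the numeric hyperparameter values are placeholders, not taken from the patch.

    # Sketch of the post-patch call pattern. Only the function and constant names
    # come from the patched modules; the checkpoint id and numbers are placeholders.
    from peft import LoraConfig

    from models.clip.embeddings import patch_managed_embeddings
    from training.functional import get_models

    pretrained = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint id

    # get_models() now returns only the raw components, without managed embeddings.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(pretrained)

    # Textual-inversion path: opt in to managed embeddings explicitly. The call is
    # idempotent after this patch, so repeating it from another code path returns
    # the already-installed ManagedCLIPTextEmbeddings instead of wrapping twice.
    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)

    # LoRA path with --lora_text_encoder_emb: the token embedding layer itself is a
    # LoRA target, which is why ensure_embeddings() rejects the TI options above
    # when the flag is set.
    text_encoder_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "token_embedding"],
        lora_dropout=0.0,
        bias="none",
    )

Keeping the two paths mutually exclusive avoids wrapping the embedding layer twice (once by ManagedCLIPTextEmbeddings, once by peft); the new isinstance guard in patch_managed_embeddings protects against the same double-patching when several TI code paths request the managed embeddings.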