From d952d467d31786f4a85cc4cb009934cd4ebbba71 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 09:09:46 +0200
Subject: Update

---
 train_lora.py | 303 ++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 261 insertions(+), 42 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 1ca56d9..39bf455 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -17,9 +18,10 @@ import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
@@ -80,6 +82,43 @@ def parse_args():
         default=None,
         help="The name of the current project.",
     )
+    parser.add_argument(
+        "--placeholder_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to use as placeholders for the concepts.",
+    )
+    parser.add_argument(
+        "--initializer_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to use as initializer words."
+    )
+    parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word."
+    )
+    parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
+        "--inverted_initializer_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to use as inverted initializer words."
+    )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
     parser.add_argument(
         "--exclude_collections",
         type=str,
@@ -186,6 +225,16 @@ def parse_args():
         type=int,
         default=2000
     )
+    parser.add_argument(
+        "--num_pti_epochs",
+        type=int,
+        default=None
+    )
+    parser.add_argument(
+        "--num_pti_steps",
+        type=int,
+        default=500
+    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
@@ -257,6 +306,12 @@ def parse_args():
         default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--learning_rate_pti",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -432,6 +487,23 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
+        type=float,
+        help="Embedding decay factor."
+    )
     parser.add_argument(
         "--max_grad_norm",
         default=1.0,
@@ -464,6 +536,40 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
+
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
+
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+
+    if isinstance(args.inverted_initializer_tokens, str):
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+
+    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
+        args.initializer_tokens += args.inverted_initializer_tokens
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
+
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+
+    if args.alias_tokens is None:
+        args.alias_tokens = []
+
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -544,6 +650,19 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
@@ -552,6 +671,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors,
+        initializer_noise=args.initializer_noise,
+    )
+    stats = list(zip(
+        args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+    ))
+    print(f"Training embeddings: {stats}")
+
     if args.scale_lr:
         args.learning_rate_unet = (
             args.learning_rate_unet * args.gradient_accumulation_steps *
@@ -561,10 +693,15 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
+        args.learning_rate_pti = (
+            args.learning_rate_pti * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
+        args.learning_rate_pti = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
@@ -663,18 +800,25 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
+        seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        offset_noise_strength=args.offset_noise_strength,
+        sample_scheduler=sample_scheduler,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )
 
-    checkpoint_output_dir = output_dir / "model"
-    sample_output_dir = output_dir/"samples"
-
-    datamodule = VlpnDataModule(
+    create_datamodule = partial(
+        VlpnDataModule,
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
@@ -693,71 +837,146 @@ def main():
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
         seed=args.seed,
+        dtype=weight_dtype,
+    )
+
+    create_lr_scheduler = partial(
+        get_scheduler,
+        args.lr_scheduler,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
+    )
+
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    pti_output_dir = output_dir / "pti"
+    pti_checkpoint_output_dir = pti_output_dir / "model"
+    pti_sample_output_dir = pti_output_dir / "samples"
+
+    pti_datamodule = create_datamodule(
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+    )
+    pti_datamodule.setup()
+
+    num_pti_epochs = args.num_pti_epochs
+    pti_sample_frequency = args.sample_frequency
+    if num_pti_epochs is None:
+        num_pti_epochs = math.ceil(
+            args.num_pti_steps / len(pti_datamodule.train_dataset)
+        ) * args.gradient_accumulation_steps
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+    pti_optimizer = create_optimizer(
+        [
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "lr": args.learning_rate_pti,
+                "weight_decay": 0,
+            },
+        ]
+    )
+
+    pti_lr_scheduler = create_lr_scheduler(
+        optimizer=pti_optimizer,
+        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+        train_epochs=num_pti_epochs,
+    )
+
+    metrics = trainer(
+        strategy=textual_inversion_strategy,
+        project="ti",
+        train_dataloader=pti_datamodule.train_dataloader,
+        val_dataloader=pti_datamodule.val_dataloader,
+        optimizer=pti_optimizer,
+        lr_scheduler=pti_lr_scheduler,
+        num_train_epochs=num_pti_epochs,
+        # --
+        sample_output_dir=pti_sample_output_dir,
+        checkpoint_output_dir=pti_checkpoint_output_dir,
+        sample_frequency=pti_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+    )
+
+    plot_metrics(metrics, output_dir/"lr.png")
output_dir/"lr.png") + + # LORA + # -------------------------------------------------------------------------------- + + lora_output_dir = output_dir / "pti" + lora_checkpoint_output_dir = lora_output_dir / "model" + lora_sample_output_dir = lora_output_dir / "samples" + + lora_datamodule = create_datamodule( filter=partial(keyword_filter, None, args.collection, args.exclude_collections), - dtype=weight_dtype ) - datamodule.setup() + lora_datamodule.setup() num_train_epochs = args.num_train_epochs - sample_frequency = args.sample_frequency + lora_sample_frequency = args.sample_frequency if num_train_epochs is None: num_train_epochs = math.ceil( - args.num_train_steps / len(datamodule.train_dataset) + args.num_train_steps / len(lora_datamodule.train_dataset) ) * args.gradient_accumulation_steps - sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) + lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) - optimizer = create_optimizer( + lora_optimizer = create_optimizer( [ { "params": unet.parameters(), "lr": args.learning_rate_unet, }, { - "params": text_encoder.parameters(), + "params": itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + ), + "lr": args.learning_rate_text, + }, + { + "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), "lr": args.learning_rate_text, + "weight_decay": 0, }, ] ) - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_training_steps_per_epoch=len(datamodule.train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps, - min_lr=args.lr_min_lr, - warmup_func=args.lr_warmup_func, - annealing_func=args.lr_annealing_func, - warmup_exp=args.lr_warmup_exp, - annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, - end_lr=1e2, + lora_lr_scheduler = create_lr_scheduler( + optimizer=lora_optimizer, + num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), train_epochs=num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, - mid_point=args.lr_mid_point, ) metrics = trainer( strategy=lora_strategy, project="lora", - train_dataloader=datamodule.train_dataloader, - val_dataloader=datamodule.val_dataloader, - seed=args.seed, - optimizer=optimizer, - lr_scheduler=lr_scheduler, + train_dataloader=lora_datamodule.train_dataloader, + val_dataloader=lora_datamodule.val_dataloader, + optimizer=lora_optimizer, + lr_scheduler=lora_lr_scheduler, num_train_epochs=num_train_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - sample_frequency=sample_frequency, - offset_noise_strength=args.offset_noise_strength, # -- - tokenizer=tokenizer, - sample_scheduler=sample_scheduler, - sample_output_dir=sample_output_dir, - checkpoint_output_dir=checkpoint_output_dir, + sample_output_dir=lora_sample_output_dir, + checkpoint_output_dir=lora_checkpoint_output_dir, + sample_frequency=lora_sample_frequency, + placeholder_tokens=args.placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay=args.emb_decay, max_grad_norm=args.max_grad_norm, - sample_batch_size=args.sample_batch_size, - sample_num_batches=args.sample_batches, - sample_num_steps=args.sample_steps, - sample_image_size=args.sample_image_size, ) plot_metrics(metrics, output_dir/"lr.png") -- cgit v1.2.3-54-g00ecf