From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 7 Feb 2023 20:44:43 +0100
Subject: Add Lora

---
 train_lora.py | 566 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 566 insertions(+)
 create mode 100644 train_lora.py

diff --git a/train_lora.py b/train_lora.py
new file mode 100644
index 0000000..2cb85cc
--- /dev/null
+++ b/train_lora.py
@@ -0,0 +1,566 @@
+import argparse
+import datetime
+import logging
+import itertools
+from pathlib import Path
+from functools import partial
+
+import torch
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import LoggerType, set_seed
+from slugify import slugify
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
+
+from util import load_config, load_embeddings_from_dir
+from data.csv import VlpnDataModule, keyword_filter
+from training.functional import train, get_models
+from training.lr import plot_metrics
+from training.strategy.lora import lora_strategy
+from training.optimization import get_scheduler
+from training.util import save_args
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A CSV file containing the training data."
+    )
+    parser.add_argument(
+        "--train_data_template",
+        type=str,
+        default="template",
+    )
+    parser.add_argument(
+        "--train_set_pad",
+        type=int,
+        default=None,
+        help="The number of items to pad the train dataset to."
+    )
+    parser.add_argument(
+        "--valid_set_pad",
+        type=int,
+        default=None,
+        help="The number of items to pad the validation dataset to."
+    )
+    parser.add_argument(
+        "--project",
+        type=str,
+        default=None,
+        help="The name of the current project.",
+    )
+    parser.add_argument(
+        "--exclude_collections",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed collection.",
+    )
+    parser.add_argument(
+        "--num_buckets",
+        type=int,
+        default=4,
+        help="Number of aspect ratio buckets in either direction.",
+    )
+    parser.add_argument(
+        "--progressive_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
+    parser.add_argument(
+        "--bucket_step_size",
+        type=int,
+        default=64,
+        help="Step size between buckets.",
+    )
+    parser.add_argument(
+        "--bucket_max_pixels",
+        type=int,
+        default=None,
+        help="Maximum pixels per bucket.",
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
+        "--no_tag_shuffle",
+        action="store_true",
+        help="Don't shuffle tags.",
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=0,
+        help="How many class images to generate."
+    )
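+    # NOTE: a non-zero --num_class_images enables the prior-preservation loss
+    # in main() (with_prior_preservation), weighted by --prior_loss_weight.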
+    parser.add_argument(
+        "--class_image_dir",
+        type=str,
+        default="cls",
+        help="The directory where class images will be saved.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/lora",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
+        "--collection",
+        type=str,
+        nargs='*',
+        help="A collection to filter the dataset.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training."
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=768,
+        help=(
+            "The resolution for input images; all the images in the train/validation dataset"
+            " will be resized to this resolution."
+        ),
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=100
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=2e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="one_cycle",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_epochs",
+        type=int,
+        default=10,
+        help="Number of epochs for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=None,
+        help="Number of restart cycles in the lr scheduler (if supported)."
+    )
+    parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=3,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=0.04,
+        help="Minimum learning rate in the lr scheduler."
+    )
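+    # NOTE: the lr_warmup_*/lr_annealing_* arguments above are passed to the
+    # custom get_scheduler() from training/optimization.py in main(); they
+    # presumably only shape the "one_cycle" schedule.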
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--adam_amsgrad",
+        action="store_true",
+        help="Whether to use the AMSGrad variant of the Adam optimizer."
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
+            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+        ),
+    )
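+    # NOTE: --mixed_precision also selects the dtype of the frozen model
+    # weights in main(): "fp16" -> torch.float16, "bf16" -> torch.bfloat16,
+    # "no" -> torch.float32 (see weight_dtype).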
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
+        "--sample_image_size",
+        type=int,
+        default=768,
+        help="Size of sample images",
+    )
+    parser.add_argument(
+        "--sample_batches",
+        type=int,
+        default=1,
+        help="Number of sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=1,
+        help="Number of samples to generate per batch",
+    )
+    parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
+        "--valid_set_repeat",
+        type=int,
+        default=1,
+        help="Times the images in the validation dataset are repeated."
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_steps",
+        type=int,
+        default=20,
+        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
+    )
+    parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
+        "--max_grad_norm",
+        default=1.0,
+        type=float,
+        help="Max gradient norm."
+    )
+    parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script."
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        args = load_config(args.config)
+        args = parser.parse_args(namespace=argparse.Namespace(**args))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.pretrained_model_name_or_path is None:
+        raise ValueError("You must specify --pretrained_model_name_or_path")
+
+    if args.project is None:
+        raise ValueError("You must specify --project")
+
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]
+
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
+    set_seed(args.seed)
+
+    save_args(output_dir, args)
+
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
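+    # Attach a LoRA processor to every attention block in the UNet. "attn1"
+    # processors are self-attention (cross_attention_dim=None); hidden_size is
+    # looked up from block_out_channels for the down/mid/up block the
+    # processor belongs to.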
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+
+    unet.set_attn_processor(lora_attn_procs)
+    lora_layers = AttnProcsLayers(unet.attn_processors)
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        embeddings.persist()
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    if args.find_lr:
+        args.learning_rate = 1e-6
+        args.lr_scheduler = "exponential_growth"
+
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        lora_layers=lora_layers,
+        noise_scheduler=noise_scheduler,
+        dtype=weight_dtype,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+    )
+
+    checkpoint_output_dir = output_dir.joinpath("model")
+    sample_output_dir = output_dir.joinpath("samples")
+
+    datamodule = VlpnDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        train_set_pad=args.train_set_pad,
+        valid_set_pad=args.valid_set_pad,
+        seed=args.seed,
+        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    optimizer = optimizer_class(
+        lora_layers.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
+
+    metrics = trainer(
+        strategy=lora_strategy,
+        project="lora",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        seed=args.seed,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        sample_output_dir=sample_output_dir,
+        checkpoint_output_dir=checkpoint_output_dir,
+        max_grad_norm=args.max_grad_norm,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
+
+    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+
+
+if __name__ == "__main__":
+    main()
--
cgit v1.2.3-54-g00ecf