"""Train LoRA adapters on a diffusion model's UNet and text encoder.

Options come from the command line and/or a JSON config file (``--config``);
see ``parse_args`` for the full list.
"""

import argparse
import datetime
import logging
import itertools
from pathlib import Path
from functools import partial

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from peft import LoraConfig, LoraModel
from slugify import slugify

from util.files import load_config, load_embeddings_from_dir
from data.csv import VlpnDataModule, keyword_filter
from training.functional import train, get_models
from training.lr import plot_metrics
from training.strategy.lora import lora_strategy
from training.optimization import get_scheduler
from training.util import save_args

# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
# Submodule names that receive LoRA adapters (passed as LoraConfig.target_modules).
UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]

logger = get_logger(__name__)

# Enable TF32 matmuls and cuDNN kernel autotuning for faster training on supported GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


def _str_to_bool(value):
    """Parse a boolean command-line value such as "true"/"false"/"1"/"0".

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``.
    Values that are already ``bool`` (e.g. from a JSON config) pass through.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args() -> argparse.Namespace:
    """Parse and validate the training options.

    If ``--config`` is given, its values are loaded first and the command line
    is re-parsed on top of them (see the comment below).

    Returns:
        The parsed namespace; ``collection``, ``exclude_collections`` and
        ``placeholder_tokens`` are normalised from a single string to a list.

    Raises:
        ValueError: if a required option (train data file, pretrained model,
            project, output dir) is missing.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A folder containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs='*',
        default=[],
        help="Placeholder tokens passed to the data module.",
    )
    parser.add_argument(
        "--train_set_pad",
        type=int,
        default=None,
        help="The number to fill train dataset items up to."
    )
    parser.add_argument(
        "--valid_set_pad",
        type=int,
        default=None,
        help="The number to fill validation dataset items up to."
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=2,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Disable tag shuffling.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=0,
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=0,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--offset_noise_strength",
        type=float,
        default=0,
        help="Perlin offset noise strength.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    # NOTE(review): --max_train_steps is parsed but never read in this file, so
    # it does not currently override --num_train_epochs as the help claims.
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=8,
        help="Lora rank, only used if use_lora is True"
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="Lora alpha, only used if use_lora is True"
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.0,
        help="Lora dropout, only used if use_lora is True"
    )
    parser.add_argument(
        "--lora_bias",
        type=str,
        default="none",
        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
    )
    parser.add_argument(
        "--lora_text_encoder_r",
        type=int,
        default=8,
        help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
    )
    parser.add_argument(
        "--lora_text_encoder_alpha",
        type=int,
        default=32,
        help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
    )
    parser.add_argument(
        "--lora_text_encoder_dropout",
        type=float,
        default=0.0,
        help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
    )
    parser.add_argument(
        "--lora_text_encoder_bias",
        type=str,
        default="none",
        help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="dadan",
        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--adam_amsgrad",
        # type=bool would turn the string "False" into True.
        type=_str_to_bool,
        default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    # NOTE(review): --lora_rank is not read in this file; LoraConfig uses --lora_r.
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=256,
        help="LoRA rank.",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=10,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
    if args.config is not None:
        # Re-parse the command line on top of the config values: options given
        # on the command line override the config, and argparse fills in
        # defaults only for options absent from both.
        # load_config is expected to return a mapping of option name -> value.
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    # A config file may provide a single string where a list is expected.
    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def _resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    """Map a ``--mixed_precision`` value to the torch dtype handed to the trainer.

    "fp16" -> float16, "bf16" -> bfloat16, anything else -> float32.
    """
    return {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }.get(mixed_precision, torch.float32)


def _make_optimizer_factory(args: argparse.Namespace):
    """Return a factory ``create_optimizer(params, lr=...)`` for ``args.optimizer``.

    For "dadam" and "dadan" this also sets ``args.learning_rate`` to 1.0,
    replacing any value produced by --learning_rate, --scale_lr or --find_lr.

    Raises:
        ImportError: if the optional library for the chosen optimizer
            (bitsandbytes or dadaptation) is not installed.
        ValueError: if ``args.optimizer`` is not a known optimizer name.
    """
    if args.optimizer == 'adam8bit':
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") from err

        return partial(
            bnb.optim.AdamW8bit,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )

    if args.optimizer == 'adam':
        return partial(
            torch.optim.AdamW,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )

    if args.optimizer == 'dadam':
        try:
            import dadaptation
        except ImportError as err:
            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") from err

        create_optimizer = partial(
            dadaptation.DAdaptAdam,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=True,
        )
        args.learning_rate = 1.0
        return create_optimizer

    if args.optimizer == 'dadan':
        try:
            import dadaptation
        except ImportError as err:
            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") from err

        create_optimizer = partial(
            dadaptation.DAdaptAdan,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
        args.learning_rate = 1.0
        return create_optimizer

    raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")


def main() -> None:
    """Run a LoRA training session end to end.

    Creates a timestamped run directory under ``<output_dir>/<project-slug>/``,
    wraps the UNet and text encoder with LoRA adapters, builds the optimizer,
    LR scheduler and data module, runs ``train`` with ``lora_strategy`` and
    finally plots the returned metrics to ``lr.png`` in the run directory.
    """
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir) / slugify(args.project) / now
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        project_dir=f"{output_dir}",
        mixed_precision=args.mixed_precision,
    )

    weight_dtype = _resolve_weight_dtype(args.mixed_precision)

    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)

    if args.seed is None:
        # torch.random.seed() returns a 64-bit value; keep the upper 32 bits.
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path
    )

    unet_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=UNET_TARGET_MODULES,
        lora_dropout=args.lora_dropout,
        bias=args.lora_bias,
    )
    unet = LoraModel(unet_config, unet)

    text_encoder_config = LoraConfig(
        r=args.lora_text_encoder_r,
        lora_alpha=args.lora_text_encoder_alpha,
        target_modules=TEXT_ENCODER_TARGET_MODULES,
        lora_dropout=args.lora_text_encoder_dropout,
        bias=args.lora_text_encoder_bias,
    )
    text_encoder = LoraModel(text_encoder_config, text_encoder)

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    # NOTE(review): `unet` is now the peft LoraModel wrapper; the calls below
    # rely on it forwarding these methods to the wrapped UNet — verify against
    # the installed peft version.
    unet.enable_xformers_memory_efficient_attention()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        # NOTE(review): persist() is called *before* the embeddings are loaded
        # from disk — confirm this ordering is intended.
        embeddings.persist()
        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    if args.find_lr:
        args.learning_rate = 1e-6
        args.lr_scheduler = "exponential_growth"

    # Must run after the learning-rate adjustments above: the D-Adaptation
    # optimizers overwrite args.learning_rate with 1.0.
    create_optimizer = _make_optimizer_factory(args)

    trainer = partial(
        train,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        dtype=weight_dtype,
        guidance_scale=args.guidance_scale,
        # Prior-preservation loss only makes sense when class images are used.
        prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
        no_val=args.valid_set_size == 0,
    )

    checkpoint_output_dir = output_dir / "model"
    sample_output_dir = output_dir / "samples"

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        with_guidance=args.guidance_scale != 0,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        placeholder_tokens=args.placeholder_tokens,
        valid_set_size=args.valid_set_size,
        train_set_pad=args.train_set_pad,
        valid_set_pad=args.valid_set_pad,
        seed=args.seed,
        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
        dtype=weight_dtype
    )
    datamodule.setup()

    # A single optimizer covers the LoRA-wrapped UNet and text encoder.
    optimizer = create_optimizer(
        itertools.chain(
            unet.parameters(),
            text_encoder.parameters(),
        ),
        lr=args.learning_rate,
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps_per_epoch=len(datamodule.train_dataloader),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        cycles=args.lr_cycles,
        end_lr=1e2,
        train_epochs=args.num_train_epochs,
        warmup_epochs=args.lr_warmup_epochs,
    )

    metrics = trainer(
        strategy=lora_strategy,
        project="lora",
        train_dataloader=datamodule.train_dataloader,
        val_dataloader=datamodule.val_dataloader,
        seed=args.seed,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        sample_frequency=args.sample_frequency,
        offset_noise_strength=args.offset_noise_strength,
        # --
        tokenizer=tokenizer,
        sample_scheduler=sample_scheduler,
        sample_output_dir=sample_output_dir,
        checkpoint_output_dir=checkpoint_output_dir,
        max_grad_norm=args.max_grad_norm,
        sample_batch_size=args.sample_batch_size,
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
    )

    plot_metrics(metrics, output_dir / "lr.png")


if __name__ == "__main__":
    main()