"""Training script that fine-tunes a Stable Diffusion UNet and, optionally, the CLIP text encoder.

Runs are tracked under the "dreambooth" tracker name; outputs go to
``<output_dir>/<slugified project>/<timestamp>``.
"""
import argparse
import itertools
import datetime
import logging
from pathlib import Path
from functools import partial
from contextlib import contextmanager, nullcontext

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, UNet2DConditionModel
import matplotlib.pyplot as plt
from transformers import CLIPTextModel
from slugify import slugify

from util import load_config, load_embeddings_from_dir
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from data.csv import VlpnDataModule, VlpnDataItem
from training.optimization import get_scheduler
from training.lr import LRFinder
from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models
from models.clip.tokenizer import MultiCLIPTokenizer

logger = get_logger(__name__)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "1"/"0", "yes"/"no", "on"/"off").

    ``type=bool`` is wrong for argparse because ``bool("False")`` is True.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _tracker_config(args):
    """Build a tracker-friendly copy of ``args``.

    TensorBoard's hparams only accept scalars and strings, so list/tuple values
    are joined into space-separated strings and any other non-scalar is
    stringified. ``None`` values are kept as-is.
    """
    config = {}
    for key, value in vars(args).items():
        if isinstance(value, (list, tuple)):
            value = " ".join(str(v) for v in value)
        elif value is not None and not isinstance(value, (bool, int, float, str)):
            value = str(value)
        config[key] = value
    return config


def parse_args():
    """Parse command-line arguments, optionally layered over a JSON config file.

    When ``--config`` is given, the loaded values seed the namespace and are then
    re-parsed so argparse fills in defaults for anything missing. Afterwards the
    token-related options are normalized to equal-length lists.

    Returns:
        argparse.Namespace with validated arguments.

    Raises:
        ValueError: if a required option is missing or list lengths disagree.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="The file containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs='*',
        default=[],
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        nargs='*',
        default=[],
        help="A token to use as initializer word."
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        default=True,
        help="Whether to train the whole text encoder."
    )
    # store_true with default=True can never be switched off; this is the counterpart.
    parser.add_argument(
        "--no_train_text_encoder",
        dest="train_text_encoder",
        action="store_false",
        help="Train only the added token embeddings instead of the whole text encoder."
    )
    parser.add_argument(
        "--train_text_encoder_epochs",
        type=int,
        default=999999,
        help="Number of epochs the text encoder will be trained."
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=4,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Disable tag shuffling.",
    )
    parser.add_argument(
        "--vector_dropout",
        type=float,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/dreambooth",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use EMA model."
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=6/7
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    # type=bool would turn the string "False" into True; parse explicitly instead.
    # A bare `--adam_amsgrad` (no value) also enables it.
    parser.add_argument(
        "--adam_amsgrad",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=99999999,
        help="How often to save a checkpoint (the large default effectively disables periodic checkpoints).",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to generate sample images",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    # Embedding normalization options, used only when the text encoder is NOT trained.
    # Strength = clamp(decay_factor * (lr - decay_start) / (learning_rate - decay_start), 0, 1).
    parser.add_argument(
        "--decay_target",
        type=float,
        default=None,
        help="Target passed to the embedding normalization; normalization is skipped when unset.",
    )
    parser.add_argument(
        "--decay_factor",
        type=float,
        default=1.0,
        help="Scale applied to the embedding normalization strength.",
    )
    parser.add_argument(
        "--decay_start",
        type=float,
        default=0.0,
        help="Learning rate at or below which the embedding normalization strength is zero.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if len(args.placeholder_tokens) == 0:
        # One auto-generated placeholder per initializer token; a lone string counts as one.
        # (The original called range() on the list itself, which always raised TypeError.)
        num_initializers = 1 if isinstance(args.initializer_token, str) else len(args.initializer_token)
        args.placeholder_tokens = [f"<*{i}>" for i in range(num_initializers)]

    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)

    if len(args.initializer_token) == 0:
        raise ValueError("You must specify --initializer_token")

    if len(args.placeholder_tokens) != len(args.initializer_token):
        raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")

    if args.num_vectors is None:
        args.num_vectors = 1

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_token)

    if len(args.placeholder_tokens) != len(args.num_vectors):
        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


class Checkpointer(CheckpointerBase):
    """Saves the full pipeline and sample images, using EMA UNet weights when EMA is enabled."""

    def __init__(
        self,
        weight_dtype: torch.dtype,
        accelerator: Accelerator,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        ema_unet: EMAModel,
        tokenizer: MultiCLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler,
        *args,
        **kwargs
    ):
        # Remaining args (dataloaders, output_dir, sampling settings, ...) go to CheckpointerBase.
        super().__init__(*args, **kwargs)

        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.ema_unet = ema_unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.scheduler = scheduler

    @torch.no_grad()
    def save_model(self):
        """Write the pipeline (with EMA UNet weights, if any) to ``<output_dir>/model``."""
        print("Saving model...")

        unet = self.accelerator.unwrap_model(self.unet)
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()

        with ema_context:
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder,
                vae=self.vae,
                unet=unet,
                tokenizer=self.tokenizer,
                scheduler=self.scheduler,
            )
            pipeline.save_pretrained(self.output_dir.joinpath("model"))

        del unet
        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def save_samples(self, step):
        """Generate sample images at ``step``, temporarily casting the models to ``weight_dtype``."""
        unet = self.accelerator.unwrap_model(self.unet)
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()

        with ema_context:
            orig_unet_dtype = unet.dtype
            orig_text_encoder_dtype = text_encoder.dtype

            unet.to(dtype=self.weight_dtype)
            text_encoder.to(dtype=self.weight_dtype)

            try:
                pipeline = VlpnStableDiffusion(
                    text_encoder=text_encoder,
                    vae=self.vae,
                    unet=unet,
                    tokenizer=self.tokenizer,
                    scheduler=self.scheduler,
                ).to(self.accelerator.device)
                pipeline.set_progress_bar_config(dynamic_ncols=True)

                super().save_samples(pipeline, step)
            finally:
                # Restore training dtypes even if sampling fails, so training is not
                # silently continued in reduced precision.
                unet.to(dtype=orig_unet_dtype)
                text_encoder.to(dtype=orig_text_encoder_dtype)

        del unet
        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    """Set up the models, data and optimizer, then run either the LR finder or training."""
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    # This check needs `accelerator`; it used to run before the Accelerator existed
    # (UnboundLocalError whenever gradient_accumulation_steps > 1).
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    # The parsed option is `initializer_token` (singular); `initializer_tokens` does not exist.
    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=args.placeholder_tokens,
        initializer_tokens=args.initializer_token,
        num_vectors=args.num_vectors
    )

    if len(placeholder_token_ids) != 0:
        initializer_token_id_lens = [len(ids) for ids in initializer_token_ids]
        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

    if args.use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            max_value=args.ema_max_decay,
        )
    else:
        ema_unet = None

    vae.requires_grad_(False)

    if args.train_text_encoder:
        print("Training entire text encoder.")

        embeddings.persist()
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
    else:
        print("Training added text embeddings")

        text_encoder.text_model.encoder.requires_grad_(False)
        text_encoder.text_model.final_layer_norm.requires_grad_(False)
        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
        text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    if args.find_lr:
        args.learning_rate = 1e-6

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    if args.train_text_encoder:
        text_encoder_params_to_optimize = text_encoder.parameters()
    else:
        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()

    # Initialize the optimizer
    optimizer = optimizer_class(
        [
            {
                'params': unet.parameters(),
            },
            {
                'params': text_encoder_params_to_optimize,
            }
        ],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    def keyword_filter(item: VlpnDataItem):
        # args.collection is a list after parse_args, so test each entry individually;
        # `list in item.collection` was never true. An empty/absent list means "no filter".
        cond3 = not args.collection or any(
            collection in item.collection
            for collection in args.collection
        )
        cond4 = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return cond3 and cond4

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        valid_set_repeat=args.valid_set_repeat,
        seed=args.seed,
        filter=keyword_filter,
        dtype=weight_dtype
    )
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

    if args.num_class_images != 0:
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            sample_scheduler,
            datamodule.data_train,
            args.sample_batch_size,
            args.sample_image_size,
            args.sample_steps
        )

    if args.find_lr:
        lr_scheduler = None
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            min_lr=args.lr_min_lr,
            warmup_func=args.lr_warmup_func,
            annealing_func=args.lr_annealing_func,
            warmup_exp=args.lr_warmup_exp,
            annealing_exp=args.lr_annealing_exp,
            cycles=args.lr_cycles,
            train_epochs=args.num_train_epochs,
            warmup_epochs=args.lr_warmup_epochs,
            num_training_steps_per_epoch=len(train_dataloader),
            gradient_accumulation_steps=args.gradient_accumulation_steps
        )

    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    vae.to(accelerator.device, dtype=weight_dtype)

    if args.use_ema:
        ema_unet.to(accelerator.device)

    @contextmanager
    def on_train(epoch: int):
        # Put modules in train mode; freeze the text encoder once its epoch budget is used up.
        tokenizer.train()

        if epoch < args.train_text_encoder_epochs:
            text_encoder.train()
        elif epoch == args.train_text_encoder_epochs:
            text_encoder.requires_grad_(False)

        yield

    @contextmanager
    def on_eval():
        # Evaluate with EMA UNet weights swapped in, when EMA is enabled.
        tokenizer.eval()
        text_encoder.eval()

        ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()

        with ema_context:
            yield

    def on_before_optimize(epoch: int):
        # Clip gradients only on steps where accumulated gradients are actually applied.
        if accelerator.sync_gradients:
            params_to_clip = [unet.parameters()]
            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
                params_to_clip.append(text_encoder.parameters())
            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)

    @torch.no_grad()
    def on_after_optimize(lr: float):
        # Embedding normalization applies only when just the added embeddings are trained
        # and a --decay_target was configured.
        if args.train_text_encoder or args.decay_target is None:
            return

        lr_range = args.learning_rate - args.decay_start
        # Guard against ZeroDivisionError when learning_rate == decay_start.
        progress = (lr - args.decay_start) / lr_range if lr_range != 0 else 1.0

        text_encoder.text_model.embeddings.normalize(
            args.decay_target,
            min(1.0, max(0.0, args.decay_factor * progress))
        )

    def on_log():
        if args.use_ema:
            return {"ema_decay": ema_unet.decay}
        return {}

    # NOTE(review): `loss_step` and `train_loop` are not imported anywhere in this file, so
    # these calls raise NameError as written; import them from the training module that
    # defines them — TODO confirm which module that is.
    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
        args.prior_loss_weight,
        args.seed,
    )

    checkpointer = Checkpointer(
        weight_dtype=weight_dtype,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        ema_unet=ema_unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        scheduler=sample_scheduler,
        placeholder_tokens=args.placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        output_dir=output_dir,
        sample_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed
    )

    if accelerator.is_main_process:
        # `config` was previously undefined; TensorBoard hparams need scalar/str values.
        accelerator.init_trackers("dreambooth", config=_tracker_config(args))

    if args.find_lr:
        lr_finder = LRFinder(
            accelerator=accelerator,
            optimizer=optimizer,
            model=unet,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            loss_step=loss_step_,
            on_train=on_train,
            on_eval=on_eval,
            on_before_optimize=on_before_optimize,
            on_after_optimize=on_after_optimize,
        )
        lr_finder.run(num_epochs=100, end_lr=1e2)

        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
        plt.close()
    else:
        # NOTE(review): on_before_optimize (gradient clipping) is passed to LRFinder but not
        # here; confirm whether train_loop clips gradients itself.
        train_loop(
            accelerator=accelerator,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            model=unet,
            checkpointer=checkpointer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            loss_step=loss_step_,
            sample_frequency=args.sample_frequency,
            checkpoint_frequency=args.checkpoint_frequency,
            global_step_offset=0,
            num_epochs=args.num_train_epochs,
            on_log=on_log,
            on_train=on_train,
            on_after_optimize=on_after_optimize,
            on_eval=on_eval
        )


if __name__ == "__main__":
    main()