"""Textual-inversion training script.

Parses the command line (optionally merged with a JSON config file), loads the
Stable Diffusion components, registers the placeholder tokens and then either
runs a learning-rate finder or the training loop. Only the parameters of
``text_encoder.text_model.embeddings.temp_token_embedding`` are handed to the
optimizer; the VAE, the UNet and the remaining text-encoder parts are frozen.
"""

import argparse
import datetime
import logging
from functools import partial
from pathlib import Path
from contextlib import contextmanager, nullcontext

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, UNet2DConditionModel
import matplotlib.pyplot as plt
from transformers import CLIPTextModel
from slugify import slugify

from util import load_config, load_embeddings_from_dir
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from data.csv import VlpnDataModule, VlpnDataItem
from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
from training.optimization import get_scheduler
from training.lr import LRFinder
from training.util import CheckpointerBase, EMAModel, save_args
from models.clip.tokenizer import MultiCLIPTokenizer

logger = get_logger(__name__)


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


# Spellings accepted by `_str_to_bool` for boolean-valued options.
_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off", ""})


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` in argparse calls ``bool("False")``, which is ``True`` for
    every non-empty string. This converter keeps ``--flag True`` working while
    making ``--flag False`` / ``--flag 0`` actually disable the option.

    Args:
        value: The raw token (a string), or an already-parsed bool.

    Returns:
        The boolean the token spells out.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value

    normalized = str(value).strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _embedding_decay_strength(lr, base_lr, decay_start, decay_factor):
    """Return the embedding-decay strength for the current learning rate.

    The strength is ``decay_factor`` times the position of ``lr`` on the ramp
    from ``decay_start`` to ``base_lr``, clamped to ``[0, 1]``.

    Args:
        lr: Learning rate of the step that was just applied.
        base_lr: The configured learning rate (upper end of the ramp).
        decay_start: Learning rate at which the ramp starts.
        decay_factor: Multiplier applied to the ramp position.

    Returns:
        A float in ``[0.0, 1.0]``.
    """
    ramp = base_lr - decay_start
    if ramp == 0:
        # The default --learning_rate and --emb_decay_start are both 1e-4, which
        # would otherwise divide by zero. A zero-width ramp degenerates to a step
        # located at `decay_start`.
        position = 1.0 if lr >= decay_start else 0.0
    else:
        position = (lr - decay_start) / ramp
    return min(1.0, max(0.0, decay_factor * position))


def parse_args():
    """Parse and validate the command-line arguments.

    When ``--config`` is given, the file is loaded through ``load_config`` and
    its entries are used as the starting namespace; the command line is then
    parsed again on top of it. Token-related arguments are normalised to lists
    of equal length.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If a required argument is missing or the token lists
            have inconsistent lengths.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs='*',
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_tokens",
        type=str,
        nargs='*',
        help="A token to use as initializer word."
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Do not shuffle tags.",
    )
    parser.add_argument(
        "--vector_dropout",
        # A probability: `int` rejected every fractional value such as 0.1.
        type=float,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=1,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use EMA model."
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=4/5
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--adam_amsgrad",
        # `type=bool` would parse "--adam_amsgrad False" as True.
        type=_str_to_bool,
        default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=5,
        help="How often to save a checkpoint and sample image (in epochs)",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to generate sample images (in epochs)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--emb_decay_target",
        default=0.4,
        type=float,
        help="Embedding decay target."
    )
    parser.add_argument(
        "--emb_decay_factor",
        default=1,
        type=float,
        help="Embedding decay factor."
    )
    parser.add_argument(
        "--emb_decay_start",
        default=1e-4,
        type=float,
        help="Embedding decay start offset."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
    )
    parser.add_argument(
        "--global_step",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
    if args.config is not None:
        # Seed the namespace from the config file, then parse the command line on top of it.
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    # Scalars can only come from the config file; the command line always yields lists or None.
    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if isinstance(args.initializer_tokens, str):
        # One initializer is shared by every placeholder (or stands alone when none were named).
        args.initializer_tokens = [args.initializer_tokens] * max(1, len(args.placeholder_tokens or []))

    # Covers both an omitted flag (None) and a flag given without values ([]).
    if not args.initializer_tokens:
        raise ValueError("You must specify --initializer_tokens")

    if not args.placeholder_tokens:
        # Generate one placeholder per initializer token.
        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

    if len(args.placeholder_tokens) != len(args.initializer_tokens):
        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

    if args.num_vectors is None:
        args.num_vectors = 1

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)

    if len(args.placeholder_tokens) != len(args.num_vectors):
        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


class Checkpointer(CheckpointerBase):
    """Saves placeholder-token embeddings and renders sample images.

    Extra positional and keyword arguments are forwarded unchanged to
    ``CheckpointerBase``.
    """

    def __init__(
        self,
        weight_dtype: torch.dtype,
        accelerator: Accelerator,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        tokenizer: MultiCLIPTokenizer,
        text_encoder: CLIPTextModel,
        ema_embeddings: EMAModel,
        scheduler,
        placeholder_tokens,
        placeholder_token_ids,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        # May be None when EMA is disabled.
        self.ema_embeddings = ema_embeddings
        self.scheduler = scheduler
        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

    def _ema_context(self, text_encoder):
        """Return a context that temporarily applies the EMA weights, or a no-op without EMA."""
        if self.ema_embeddings is None:
            return nullcontext()
        return self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        """Write one embedding file per placeholder token.

        Files are stored in ``<output_dir>/checkpoints`` and named
        ``<slugified token>_<step>_<postfix>.bin``.

        Args:
            step: Step number used in the log message and the file name.
            postfix: Suffix appended to the file name.
        """
        print("Saving checkpoint for step %d..." % step)

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        with self._ema_context(text_encoder):
            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

        del text_encoder

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Build a sampling pipeline and delegate to ``CheckpointerBase.save_samples``.

        The text encoder is cast to ``weight_dtype`` while sampling and cast
        back to its previous dtype afterwards, even if sampling raises.

        Args:
            step: Step number forwarded to the base implementation.
            num_inference_steps: Forwarded to the base implementation.
            guidance_scale: Forwarded to the base implementation.
            eta: Forwarded to the base implementation.
        """
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
        pipeline = None

        with self._ema_context(text_encoder):
            orig_dtype = text_encoder.dtype
            text_encoder.to(dtype=self.weight_dtype)

            try:
                pipeline = VlpnStableDiffusion(
                    text_encoder=text_encoder,
                    vae=self.vae,
                    unet=self.unet,
                    tokenizer=self.tokenizer,
                    scheduler=self.scheduler,
                ).to(self.accelerator.device)
                pipeline.set_progress_bar_config(dynamic_ncols=True)

                super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
            finally:
                # Training continues with this module; never leave it in the sampling dtype.
                text_encoder.to(dtype=orig_dtype)

        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    """Run the script: set up models, data and callbacks, then find an LR or train."""
    args = parse_args()

    global_step_offset = args.global_step
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    # Saved after the seed is fixed so the stored arguments contain the seed actually used.
    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        # `is_dir()` is already False for a path that does not exist.
        if not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=args.placeholder_tokens,
        initializer_tokens=args.initializer_tokens,
        num_vectors=args.num_vectors
    )

    if len(placeholder_token_ids) != 0:
        initializer_token_id_lens = [len(token_ids) for token_ids in initializer_token_ids]
        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

    if args.use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            max_value=args.ema_max_decay,
        )
    else:
        ema_embeddings = None

    # Freeze everything except `temp_token_embedding`, which is what the optimizer updates.
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    if args.find_lr:
        # Starting learning rate for the LR finder sweep.
        args.learning_rate = 1e-5

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    def keyword_filter(item: VlpnDataItem):
        """Keep items that mention a placeholder token and pass the collection filters."""
        has_placeholder = any(
            keyword in part
            for keyword in args.placeholder_tokens
            for part in item.prompt
        )
        # `args.collection` is a list, so testing `args.collection in item.collection`
        # asked whether the list itself is an element. Test each name instead, exactly
        # like the exclusion check below.
        # NOTE(review): assumes `item.collection` is a container of collection names,
        # as the exclusion check already treats it - confirm against VlpnDataItem.
        in_collection = args.collection is None or any(
            collection in item.collection
            for collection in args.collection
        )
        not_excluded = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return has_placeholder and in_collection and not_excluded

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        valid_set_repeat=args.valid_set_repeat,
        seed=args.seed,
        filter=keyword_filter,
        dtype=weight_dtype
    )
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

    if args.num_class_images != 0:
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            sample_scheduler,
            datamodule.data_train,
            args.sample_batch_size,
            args.sample_image_size,
            args.sample_steps
        )

    if args.find_lr:
        # The LR finder drives the learning rate itself.
        lr_scheduler = None
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_training_steps_per_epoch=len(train_dataloader),
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            min_lr=args.lr_min_lr,
            warmup_func=args.lr_warmup_func,
            annealing_func=args.lr_annealing_func,
            warmup_exp=args.lr_warmup_exp,
            annealing_exp=args.lr_annealing_exp,
            cycles=args.lr_cycles,
            train_epochs=args.num_train_epochs,
            warmup_epochs=args.lr_warmup_epochs,
        )

    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    if args.use_ema:
        ema_embeddings.to(accelerator.device)

    if args.gradient_checkpointing:
        unet.train()
    else:
        unet.eval()

    @contextmanager
    def on_train(epoch: int):
        """Switch the tokenizer to training mode for the duration of the context."""
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        """Switch the tokenizer to eval mode and, with EMA, evaluate on the EMA weights."""
        tokenizer.eval()

        if args.use_ema:
            ema_context = ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
        else:
            ema_context = nullcontext()

        with ema_context:
            yield

    @torch.no_grad()
    def on_after_optimize(lr: float):
        """Normalize the embeddings after an optimizer step and advance the EMA."""
        text_encoder.text_model.embeddings.normalize(
            args.emb_decay_target,
            _embedding_decay_strength(lr, args.learning_rate, args.emb_decay_start, args.emb_decay_factor)
        )

        if args.use_ema:
            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    def on_log():
        """Return extra values to log: the current EMA decay when EMA is enabled."""
        if args.use_ema:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
        args.prior_loss_weight,
        args.seed,
    )

    checkpointer = Checkpointer(
        weight_dtype=weight_dtype,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        ema_embeddings=ema_embeddings,
        scheduler=sample_scheduler,
        placeholder_tokens=args.placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        output_dir=output_dir,
        sample_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion")

    if args.find_lr:
        lr_finder = LRFinder(
            accelerator=accelerator,
            optimizer=optimizer,
            model=text_encoder,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            loss_step=loss_step_,
            on_train=on_train,
            on_eval=on_eval,
            on_after_optimize=on_after_optimize,
        )
        lr_finder.run(num_epochs=100, end_lr=1e3)

        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
        plt.close()
    else:
        train_loop(
            accelerator=accelerator,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            model=text_encoder,
            checkpointer=checkpointer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            loss_step=loss_step_,
            sample_frequency=args.sample_frequency,
            checkpoint_frequency=args.checkpoint_frequency,
            global_step_offset=global_step_offset,
            num_epochs=args.num_train_epochs,
            on_log=on_log,
            on_train=on_train,
            on_after_optimize=on_after_optimize,
            on_eval=on_eval
        )


if __name__ == "__main__":
    main()