import argparse
import itertools
import math
import datetime
import logging
from pathlib import Path
from functools import partial
from contextlib import contextmanager, nullcontext

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
import matplotlib.pyplot as plt
from diffusers.training_utils import EMAModel
from tqdm.auto import tqdm
from transformers import CLIPTextModel
from slugify import slugify

from util import load_config, load_embeddings_from_dir
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from data.csv import VlpnDataModule, VlpnDataItem
from training.common import run_model, generate_class_images
from training.optimization import get_one_cycle_schedule
from training.lr import LRFinder
from training.util import AverageMeter, CheckpointerBase, save_args
from models.clip.embeddings import patch_managed_embeddings
from models.clip.prompt import PromptProcessor
from models.clip.tokenizer import MultiCLIPTokenizer

logger = get_logger(__name__)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A folder containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        nargs='*',
        default=[],
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        nargs='*',
        default=[],
        help="A token to use as initializer word."
    )
    # --num_vectors is referenced below via tokenizer.add_multi_tokens(); the default of a
    # single vector per placeholder token is an assumption.
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        default=1,
        help="Number of embedding vectors per placeholder token.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        default=True,
        help="Whether to train the whole text encoder."
    )
    parser.add_argument(
        "--train_text_encoder_epochs",
        type=int,
        default=999999,
        help="Number of epochs the text encoder will be trained."
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=4,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--tag_shuffle",
        action="store_true",  # was `type="store_true"`, which argparse rejects
        default=True,
        help="Shuffle tags.",
    )
    parser.add_argument(
        "--vector_dropout",
        type=float,  # a dropout probability, so float rather than int
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/dreambooth",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=None,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use EMA model."
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=6 / 7
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--adam_amsgrad",
        type=bool,
        default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often (in epochs) to save a checkpoint and sample images",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )
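    # The --config option loads arguments from a JSON file whose keys are the argument
    # names defined above; values given explicitly on the command line still take
    # precedence, since the parser is re-run with the loaded values as the namespace
    # (see below). A minimal sketch of such a file, with hypothetical values:
    #
    #   {
    #       "pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
    #       "train_data_file": "path/to/training/data",
    #       "project": "my-concept",
    #       "placeholder_token": ["<token>"],
    #       "initializer_token": ["person"]
    #   }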
    args = parser.parse_args()

    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token]

    if isinstance(args.placeholder_token, str):
        args.placeholder_token = [args.placeholder_token]

    if len(args.placeholder_token) == 0:
        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]

    if len(args.placeholder_token) != len(args.initializer_token):
        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


class Checkpointer(CheckpointerBase):
    def __init__(
        self,
        weight_dtype,
        datamodule,
        accelerator,
        vae,
        unet,
        ema_unet,
        tokenizer,
        text_encoder,
        scheduler,
        output_dir: Path,
        placeholder_token,
        placeholder_token_id,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed,
    ):
        super().__init__(
            datamodule=datamodule,
            output_dir=output_dir,
            placeholder_token=placeholder_token,
            placeholder_token_id=placeholder_token_id,
            sample_image_size=sample_image_size,
            seed=seed or torch.random.seed(),
            sample_batches=sample_batches,
            sample_batch_size=sample_batch_size
        )

        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.ema_unet = ema_unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.scheduler = scheduler

    @torch.no_grad()
    def save_model(self):
        print("Saving model...")

        unet = self.accelerator.unwrap_model(self.unet)
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()

        with ema_context:
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder,
                vae=self.vae,
                unet=unet,
                tokenizer=self.tokenizer,
                scheduler=self.scheduler,
            )
            pipeline.save_pretrained(self.output_dir.joinpath("model"))

        del unet
        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        unet = self.accelerator.unwrap_model(self.unet)
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()

        with ema_context:
            orig_unet_dtype = unet.dtype
            orig_text_encoder_dtype = text_encoder.dtype

            unet.to(dtype=self.weight_dtype)
            text_encoder.to(dtype=self.weight_dtype)

            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder,
                vae=self.vae,
                unet=unet,
                tokenizer=self.tokenizer,
                scheduler=self.scheduler,
            ).to(self.accelerator.device)
            pipeline.set_progress_bar_config(dynamic_ncols=True)

            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)

            unet.to(dtype=orig_unet_dtype)
            text_encoder.to(dtype=orig_text_encoder_dtype)

        del unet
        del text_encoder
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    args = parse_args()
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
    basepath.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{basepath}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    # This check needs `accelerator`, so it has to run after the Accelerator is created.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

    args.seed = args.seed or (torch.random.seed() >> 32)
    set_seed(args.seed)

    save_args(basepath, args)

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder='scheduler')

    ema_unet = None

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if args.use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            max_value=args.ema_max_decay,
        )

    embeddings = patch_managed_embeddings(text_encoder)

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    if len(args.placeholder_token) != 0:
        # Convert the initializer_token, placeholder_token to ids
        initializer_token_ids = [
            tokenizer.encode(token, add_special_tokens=False)
            for token in args.initializer_token
        ]

        new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
        embeddings.resize(len(tokenizer))

        for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
            embeddings.add_embed(new_id, init_ids)

        init_ratios = [
            f"{len(init_ids)} / {len(new_id)}"
            for new_id, init_ids in zip(new_ids, initializer_token_ids)
        ]

        print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")

        # Keep the ids of the newly added tokens; the Checkpointer below expects them.
        placeholder_token_id = new_ids
    else:
        placeholder_token_id = []

    vae.requires_grad_(False)

    if args.train_text_encoder:
        print("Training entire text encoder.")

        embeddings.persist()
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
    else:
print(f"Training added text embeddings") text_encoder.text_model.encoder.requires_grad_(False) text_encoder.text_model.final_layer_norm.requires_grad_(False) text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) prompt_processor = PromptProcessor(tokenizer, text_encoder) if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) if args.find_lr: args.learning_rate = 1e-6 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW if args.train_text_encoder: text_encoder_params_to_optimize = text_encoder.parameters() else: text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters() # Initialize the optimizer optimizer = optimizer_class( [ { 'params': unet.parameters(), }, { 'params': text_encoder_params_to_optimize, } ], lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) weight_dtype = torch.float32 if args.mixed_precision == "fp16": weight_dtype = torch.float16 elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 def keyword_filter(item: VlpnDataItem): cond3 = args.collection is None or args.collection in item.collection cond4 = args.exclude_collections is None or not any( collection in item.collection for collection in args.exclude_collections ) return cond3 and cond4 def collate_fn(examples): prompt_ids = [example["prompt_ids"] for example in examples] nprompt_ids = [example["nprompt_ids"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] # concat class and instance examples for prior preservation if args.num_class_images != 0 and "class_prompt_ids" in examples[0]: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) prompts = prompt_processor.unify_input_ids(prompt_ids) nprompts = prompt_processor.unify_input_ids(nprompt_ids) inputs = prompt_processor.unify_input_ids(input_ids) batch = { "prompt_ids": prompts.input_ids, "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, "pixel_values": pixel_values, "attention_mask": inputs.attention_mask, } return batch datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, num_buckets=args.num_buckets, progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, dropout=args.tag_dropout, shuffle=args.tag_shuffle, template_key=args.train_data_template, valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, num_workers=args.dataloader_num_workers, seed=args.seed, filter=keyword_filter, collate_fn=collate_fn ) datamodule.prepare_data() 
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

    if args.num_class_images != 0:
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            checkpoint_scheduler,
            datamodule.data_train,
            args.sample_batch_size,
            args.sample_image_size,
            args.sample_steps
        )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps

    if args.lr_scheduler == "one_cycle":
        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
        lr_scheduler = get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            warmup=args.lr_warmup_func,
            annealing=args.lr_annealing_func,
            warmup_exp=args.lr_warmup_exp,
            annealing_exp=args.lr_annealing_exp,
            min_lr=lr_min_lr,
        )
    elif args.lr_scheduler == "cosine_with_restarts":
        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
        )
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        )

    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Move vae to device and keep it in eval mode, as it is not trained
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.eval()

    if args.use_ema:
        ema_unet.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    num_val_steps_per_epoch = len(val_dataloader)
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    val_steps = num_val_steps_per_epoch * num_epochs

    @contextmanager
    def on_train():
        try:
            tokenizer.train()
            yield
        finally:
            pass

    @contextmanager
    def on_eval():
        try:
            tokenizer.eval()

            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()

            with ema_context:
                yield
        finally:
            pass

    loop = partial(
        run_model,
        vae,
        noise_scheduler,
        unet,
        prompt_processor,
        args.num_class_images,
        args.prior_loss_weight,
        args.seed,
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        config = vars(args).copy()
        config["initializer_token"] = " ".join(config["initializer_token"])
        config["placeholder_token"] = " ".join(config["placeholder_token"])
        if config["collection"] is not None:
            config["collection"] = " ".join(config["collection"])
        if config["exclude_collections"] is not None:
            config["exclude_collections"] = " ".join(config["exclude_collections"])
        accelerator.init_trackers("dreambooth", config=config)

    if args.find_lr:
        lr_finder = LRFinder(
            accelerator,
            text_encoder,
            optimizer,
            train_dataloader,
            val_dataloader,
            loop,
            on_train=tokenizer.train,
            on_eval=tokenizer.eval,
        )
        lr_finder.run(end_lr=1e2)

        plt.savefig(basepath.joinpath("lr.png"))
        plt.close()

        quit()

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.

    global_step = 0

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    max_acc_val = 0.0

    checkpointer = Checkpointer(
        weight_dtype=weight_dtype,
        datamodule=datamodule,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        ema_unet=ema_unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        scheduler=checkpoint_scheduler,
        output_dir=basepath,
        placeholder_token=args.placeholder_token,
        placeholder_token_id=placeholder_token_id,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed
    )

    local_progress_bar = tqdm(
        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(args.max_train_steps + val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % args.sample_frequency == 0:
                    checkpointer.save_samples(global_step, args.sample_steps)

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            unet.train()
            if epoch < args.train_text_encoder_epochs:
                text_encoder.train()
            elif epoch == args.train_text_encoder_epochs:
                text_encoder.requires_grad_(False)

            with on_train():
                for step, batch in enumerate(train_dataloader):
                    with accelerator.accumulate(unet):
                        loss, acc, bsz = loop(step, batch)

                        accelerator.backward(loss)

                        if accelerator.sync_gradients:
                            params_to_clip = (
                                itertools.chain(unet.parameters(), text_encoder.parameters())
                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                                else unet.parameters()
                            )
                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                        optimizer.step()
                        if not accelerator.optimizer_step_was_skipped:
                            lr_scheduler.step()
                            if args.use_ema:
                                ema_unet.step(unet.parameters())
                        optimizer.zero_grad(set_to_none=True)

                        avg_loss.update(loss.detach_(), bsz)
                        avg_acc.update(acc.detach_(), bsz)

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        global_step += 1

                        logs = {
                            "train/loss": avg_loss.avg.item(),
                            "train/acc": avg_acc.avg.item(),
                            "train/cur_loss": loss.item(),
                            "train/cur_acc": acc.item(),
                            "lr": lr_scheduler.get_last_lr()[0]
                        }
                        if args.use_ema:
                            logs["ema_decay"] = 1 - ema_unet.decay

                        accelerator.log(logs, step=global_step)

                        local_progress_bar.set_postfix(**logs)

                    if global_step >= args.max_train_steps:
                        break

            accelerator.wait_for_everyone()

            unet.eval()
            text_encoder.eval()

            cur_loss_val = AverageMeter()
            cur_acc_val = AverageMeter()

            with torch.inference_mode():
                with on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loop(step, batch, True)

                        loss = loss.detach_()
                        acc = acc.detach_()

                        cur_loss_val.update(loss, bsz)
                        cur_acc_val.update(acc, bsz)

                        avg_loss_val.update(loss, bsz)
                        avg_acc_val.update(acc, bsz)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        logs = {
                            "val/loss": avg_loss_val.avg.item(),
                            "val/acc": avg_acc_val.avg.item(),
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

            logs["val/cur_loss"] = cur_loss_val.avg.item()
            logs["val/cur_acc"] = cur_acc_val.avg.item()

            accelerator.log(logs, step=global_step)

            local_progress_bar.clear()
            global_progress_bar.clear()

            if accelerator.is_main_process:
                if avg_acc_val.avg.item() > max_acc_val:
                    accelerator.print(
                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                    max_acc_val = avg_acc_val.avg.item()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.save_samples(global_step, args.sample_steps)
            checkpointer.save_model()
            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.save_model()
            accelerator.end_training()
        quit()


if __name__ == "__main__":
    main()
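# Example invocation (a sketch; the script file name and all paths/values below are
# hypothetical and must be adapted, but every flag shown is defined in parse_args above):
#
#   python dreambooth.py \
#       --pretrained_model_name_or_path stabilityai/stable-diffusion-2-1 \
#       --train_data_file path/to/training/data \
#       --project my-concept \
#       --placeholder_token "<token>" \
#       --initializer_token person \
#       --train_batch_size 1 \
#       --gradient_accumulation_steps 4 \
#       --use_ema
#
# The same values can instead be placed in a JSON file and passed via --config.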