"""Train textual-inversion embeddings for a Stable Diffusion text encoder.

The script parses its options (optionally merged with a JSON config file),
registers the placeholder tokens with the tokenizer, initializes their
embeddings from the initializer tokens and then optimizes only the temporary
token embedding of the text encoder. Checkpoints, sample images, the resolved
arguments and a log file are written below ``<output_dir>/<project>/<timestamp>``.
"""

import argparse
import math
import datetime
import logging
from functools import partial
from pathlib import Path
from contextlib import contextmanager, nullcontext

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from transformers import CLIPTextModel
from slugify import slugify

from util import load_config, load_embeddings_from_dir
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from data.csv import VlpnDataModule, VlpnDataItem
from training.common import run_model, generate_class_images
from training.optimization import get_one_cycle_schedule
from training.lr import LRFinder
from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
from models.clip.embeddings import patch_managed_embeddings
from models.clip.prompt import PromptProcessor
from models.clip.tokenizer import MultiCLIPTokenizer

logger = get_logger(__name__)


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


def _str_to_bool(value):
    """Parse a command line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value

    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False

    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args():
    """Parse and validate the command line (and optional JSON config) options.

    When ``--config`` is given, the values loaded from that file seed the
    namespace and the command line is parsed a second time on top of it.

    Returns:
        The argparse namespace. ``initializer_token``, ``placeholder_token`` and
        ``num_vectors`` are normalized to lists of equal length; ``collection``
        and ``exclude_collections`` are lists or ``None``.

    Raises:
        ValueError: if a required option is missing or the token lists do not
            have matching lengths.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        nargs='*',
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        nargs='*',
        help="A token to use as initializer word."
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--vector_dropout",
        # A probability: declaring it as int made fractional values unparseable.
        type=float,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=1,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=None,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use EMA model."
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=4/5
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--adam_amsgrad",
        # Not type=bool: that would turn "--adam_amsgrad False" into True.
        type=_str_to_bool,
        default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=5,
        help="How often to save a checkpoint (in epochs)",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save sample images (in epochs)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=15,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
    )
    parser.add_argument(
        "--global_step",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
    if args.config is not None:
        # Seed the namespace with the config values, then parse the command
        # line again on top of it.
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token]

    # nargs='*' without a default yields None when the option is omitted, so
    # test truthiness instead of calling len() on a possible None.
    if not args.initializer_token:
        raise ValueError("You must specify --initializer_token")

    if isinstance(args.placeholder_token, str):
        args.placeholder_token = [args.placeholder_token]

    if not args.placeholder_token:
        # One generated placeholder per initializer token.
        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]

    if args.num_vectors is None:
        args.num_vectors = 1

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_token)

    if len(args.placeholder_token) != len(args.initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

    if len(args.placeholder_token) != len(args.num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


class Checkpointer(CheckpointerBase):
    """Saves the learned placeholder embeddings and renders sample images.

    Everything not listed explicitly in ``__init__`` is forwarded to
    ``CheckpointerBase``.
    """

    def __init__(
        self,
        weight_dtype,
        accelerator: Accelerator,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        tokenizer: MultiCLIPTokenizer,
        text_encoder: CLIPTextModel,
        ema_embeddings: EMAModel,
        scheduler,
        placeholder_token,
        new_ids,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.ema_embeddings = ema_embeddings
        self.scheduler = scheduler
        self.placeholder_token = placeholder_token
        self.new_ids = new_ids

    def _ema_context(self, text_encoder):
        """Return a context that temporarily applies the EMA weights, if any."""
        if self.ema_embeddings is None:
            return nullcontext()

        return self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        """Write one ``<token>_<step>_<postfix>.bin`` file per placeholder token.

        Files go to ``<output_dir>/checkpoints``; EMA weights are used while
        saving when an EMA model was supplied.
        """
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        with self._ema_context(text_encoder):
            for (token, ids) in zip(self.placeholder_token, self.new_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

        del text_encoder

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Build an inference pipeline from the current weights and sample from it.

        The text encoder is cast to ``weight_dtype`` for sampling and restored
        to its previous dtype afterwards, even when sampling raises.
        """
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        with self._ema_context(text_encoder):
            orig_dtype = text_encoder.dtype
            text_encoder.to(dtype=self.weight_dtype)

            try:
                pipeline = VlpnStableDiffusion(
                    text_encoder=text_encoder,
                    vae=self.vae,
                    unet=self.unet,
                    tokenizer=self.tokenizer,
                    scheduler=self.scheduler,
                ).to(self.accelerator.device)
                pipeline.set_progress_bar_config(dynamic_ncols=True)

                super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)

                del pipeline
            finally:
                # Training must continue in the original precision.
                text_encoder.to(dtype=orig_dtype)

        del text_encoder

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    """Run the training (or the learning-rate finder when ``--find_lr`` is set)."""
    args = parse_args()

    global_step_offset = args.global_step
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
    basepath.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{basepath}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

    # `args.seed or ...` would silently discard an explicit seed of 0.
    if args.seed is None:
        args.seed = torch.random.seed() >> 32
    set_seed(args.seed)

    save_args(basepath, args)

    # Load the tokenizer; the placeholder tokens are added to it further below
    if args.tokenizer_name:
        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder='scheduler')

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    embeddings = patch_managed_embeddings(text_encoder)
    ema_embeddings = None

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    # Convert the initializer_token, placeholder_token to ids
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in args.initializer_token
    ]

    new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
    embeddings.resize(len(tokenizer))

    for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
        embeddings.add_embed(new_id, init_ids)

    init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]

    print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")

    if args.use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            max_value=args.ema_max_decay,
        )

    # Freeze everything except the temporary token embedding.
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)

    prompt_processor = PromptProcessor(tokenizer, text_encoder)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    if args.find_lr:
        # The LR finder sweeps upwards from this starting value.
        args.learning_rate = 1e-5

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Initialize the optimizer
    optimizer = optimizer_class(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    def keyword_filter(item: VlpnDataItem):
        """Keep items that mention a placeholder token and pass the collection filters."""
        cond1 = any(
            keyword in part
            for keyword in args.placeholder_token
            for part in item.prompt
        )
        # args.collection is a list; testing the list itself for membership in
        # item.collection could never match, so test its entries instead.
        # NOTE(review): "any requested collection" semantics assumed - confirm.
        cond3 = args.collection is None or any(
            collection in item.collection
            for collection in args.collection
        )
        cond4 = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return cond1 and cond3 and cond4

    def collate_fn(examples):
        """Assemble a batch dict from the dataset examples."""
        prompt_ids = [example["prompt_ids"] for example in examples]
        nprompt_ids = [example["nprompt_ids"] for example in examples]

        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

        prompts = prompt_processor.unify_input_ids(prompt_ids)
        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
        inputs = prompt_processor.unify_input_ids(input_ids)

        batch = {
            "prompt_ids": prompts.input_ids,
            "nprompt_ids": nprompts.input_ids,
            "input_ids": inputs.input_ids,
            "pixel_values": pixel_values,
            "attention_mask": inputs.attention_mask,
        }
        return batch

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        prompt_processor=prompt_processor,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        num_workers=args.dataloader_num_workers,
        seed=args.seed,
        filter=keyword_filter,
        collate_fn=collate_fn
    )
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

    if args.num_class_images != 0:
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            checkpoint_scheduler,
            datamodule.data_train,
            args.sample_batch_size,
            args.sample_image_size,
            args.sample_steps
        )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps

    if args.find_lr:
        # The LR finder drives the learning rate itself.
        lr_scheduler = None
    elif args.lr_scheduler == "one_cycle":
        # The schedule takes the minimum as a fraction of the peak learning rate.
        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
        lr_scheduler = get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            warmup=args.lr_warmup_func,
            annealing=args.lr_annealing_func,
            warmup_exp=args.lr_warmup_exp,
            annealing_exp=args.lr_annealing_exp,
            min_lr=lr_min_lr,
        )
    elif args.lr_scheduler == "cosine_with_restarts":
        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
        )
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        )

    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    if args.use_ema:
        ema_embeddings.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()

    if args.gradient_checkpointing:
        unet.train()
    else:
        unet.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    num_val_steps_per_epoch = len(val_dataloader)
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    val_steps = num_val_steps_per_epoch * num_epochs

    @contextmanager
    def on_train():
        """Switch the tokenizer to training mode for the enclosed block."""
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        """Switch the tokenizer to eval mode and apply the EMA weights, if enabled."""
        tokenizer.eval()

        ema_context = ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()

        with ema_context:
            yield

    loop = partial(
        run_model,
        vae,
        noise_scheduler,
        unet,
        prompt_processor,
        args.num_class_images,
        args.prior_loss_weight,
        args.seed,
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        # List-valued options are flattened to strings for the tracker config.
        config = vars(args).copy()
        config["initializer_token"] = " ".join(config["initializer_token"])
        config["placeholder_token"] = " ".join(config["placeholder_token"])
        config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
        if config["collection"] is not None:
            config["collection"] = " ".join(config["collection"])
        if config["exclude_collections"] is not None:
            config["exclude_collections"] = " ".join(config["exclude_collections"])
        accelerator.init_trackers("textual_inversion", config=config)

    if args.find_lr:
        lr_finder = LRFinder(
            accelerator,
            text_encoder,
            optimizer,
            train_dataloader,
            val_dataloader,
            loop,
            on_train=on_train,
            on_eval=on_eval,
        )
        lr_finder.run(num_epochs=200, end_lr=1e3)

        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
        plt.close()

        # No training in LR-finder mode. quit() is an interactive helper from
        # the site module, so exit explicitly instead.
        raise SystemExit

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    max_acc_val = 0.0

    checkpointer = Checkpointer(
        weight_dtype=weight_dtype,
        datamodule=datamodule,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        ema_embeddings=ema_embeddings,
        scheduler=checkpoint_scheduler,
        placeholder_token=args.placeholder_token,
        new_ids=new_ids,
        output_dir=basepath,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed
    )

    # Only show the progress bar once on each machine.
    local_progress_bar = tqdm(
        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(args.max_train_steps + val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % args.sample_frequency == 0:
                    checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            text_encoder.train()

            with on_train():
                for step, batch in enumerate(train_dataloader):
                    with accelerator.accumulate(text_encoder):
                        loss, acc, bsz = loop(step, batch)

                        accelerator.backward(loss)

                        optimizer.step()
                        if not accelerator.optimizer_step_was_skipped:
                            lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        avg_loss.update(loss.detach_(), bsz)
                        avg_acc.update(acc.detach_(), bsz)

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
                        if args.use_ema:
                            ema_embeddings.step(
                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        global_step += 1

                    logs = {
                        "train/loss": avg_loss.avg.item(),
                        "train/acc": avg_acc.avg.item(),
                        "train/cur_loss": loss.item(),
                        "train/cur_acc": acc.item(),
                        "lr": lr_scheduler.get_last_lr()[0],
                    }
                    if args.use_ema:
                        logs["ema_decay"] = ema_embeddings.decay

                    accelerator.log(logs, step=global_step)

                    local_progress_bar.set_postfix(**logs)

                    if global_step >= args.max_train_steps:
                        break

            accelerator.wait_for_everyone()

            text_encoder.eval()

            # Per-epoch validation meters; the avg_*_val meters span all epochs.
            cur_loss_val = AverageMeter()
            cur_acc_val = AverageMeter()

            with torch.inference_mode():
                with on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loop(step, batch, True)

                        loss = loss.detach_()
                        acc = acc.detach_()

                        cur_loss_val.update(loss, bsz)
                        cur_acc_val.update(acc, bsz)

                        avg_loss_val.update(loss, bsz)
                        avg_acc_val.update(acc, bsz)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        logs = {
                            "val/loss": avg_loss_val.avg.item(),
                            "val/acc": avg_acc_val.avg.item(),
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

            # Report the epoch averages rather than the last batch's values.
            logs["val/cur_loss"] = cur_loss_val.avg.item()
            logs["val/cur_acc"] = cur_acc_val.avg.item()

            accelerator.log(logs, step=global_step)

            local_progress_bar.clear()
            global_progress_bar.clear()

            if accelerator.is_main_process:
                if avg_acc_val.avg.item() > max_acc_val:
                    accelerator.print(
                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                    max_acc_val = avg_acc_val.avg.item()

                if (epoch + 1) % args.checkpoint_frequency == 0:
                    checkpointer.checkpoint(global_step + global_step_offset, "training")
                    save_args(basepath, args, {
                        "global_step": global_step + global_step_offset
                    })

        # Save the final embeddings, samples and resume state.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.checkpoint(global_step + global_step_offset, "end")
            checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
            save_args(basepath, args, {
                "global_step": global_step + global_step_offset
            })
            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.checkpoint(global_step + global_step_offset, "end")
            save_args(basepath, args, {
                "global_step": global_step + global_step_offset
            })
            accelerator.end_training()
        raise SystemExit


if __name__ == "__main__":
    main()