From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 13:49:35 +0100
Subject: Removed PromptProcessor, modularized training loop

---
 train_ti.py | 268 ++++++++++++------------------------------------------------
 1 file changed, 53 insertions(+), 215 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index e18ee38..8c86586 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,11 +21,10 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images, get_scheduler
+from training.common import loss_step, train_loop, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
-from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
@@ -197,12 +196,6 @@ def parse_args():
         type=int,
         default=100
     )
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
-    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
@@ -409,7 +402,7 @@ def parse_args():
     )
     parser.add_argument(
         "--decay_target",
-        default=0.4,
+        default=None,
         type=float,
         help="Embedding decay target."
     )
@@ -668,8 +661,6 @@ def main():
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -722,7 +713,7 @@ def main():
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        prompt_processor=prompt_processor,
+        tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -759,13 +750,7 @@ def main():
             args.sample_steps
         )
 
-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.find_lr:
         lr_scheduler = None
@@ -781,7 +766,7 @@ def main():
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
             warmup_epochs=args.lr_warmup_epochs,
-            max_train_steps=args.max_train_steps,
+            num_train_epochs=args.num_train_epochs,
             num_update_steps_per_epoch=num_update_steps_per_epoch,
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
@@ -805,15 +790,6 @@ def main():
     else:
         unet.eval()
 
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
-    num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    val_steps = num_val_steps_per_epoch * num_epochs
-
     @contextmanager
     def on_train():
         try:
@@ -842,19 +818,44 @@ def main():
             min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
         )
 
+        if args.use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if args.use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
     loop = partial(
         loss_step,
         vae,
         noise_scheduler,
         unet,
-        prompt_processor,
+        text_encoder,
         args.num_class_images,
         args.prior_loss_weight,
        args.seed,
     )
 
-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
+    checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        scheduler=checkpoint_scheduler,
+        placeholder_token=args.placeholder_token,
+        new_ids=new_ids,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
@@ -882,190 +883,27 @@ def main():
 
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
-
-        quit()
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-
-    global_step = 0
-
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    max_acc_val = 0.0
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
-        scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        new_ids=new_ids,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
-
-    global_progress_bar = tqdm(
-        range(args.max_train_steps + val_steps),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    global_progress_bar.set_description("Total progress")
-
-    try:
-        for epoch in range(num_epochs):
-            if accelerator.is_main_process:
-                if epoch % args.sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
-
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            text_encoder.train()
-
-            with on_train():
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        optimizer.step()
-                        lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    if args.use_ema:
-                        logs["ema_decay"] = ema_embeddings.decay
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
-
-            accelerator.wait_for_everyone()
-
-            text_encoder.eval()
-
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loop(step, batch, True)
-
-                        loss = loss.detach_()
-                        acc = acc.detach_()
-
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
-
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
-                            "val/cur_loss": loss.item(),
-                            "val/cur_acc": acc.item(),
-                        }
-                        local_progress_bar.set_postfix(**logs)
-
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
-
-            accelerator.log(logs, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > max_acc_val:
-                    accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                    max_acc_val = avg_acc_val.avg.item()
-
-                if (epoch + 1) % args.checkpoint_frequency == 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_args(basepath, args, {
-                        "global_step": global_step + global_step_offset
-                    })
-
-        # Create the pipeline using using the trained modules and save it.
-        if accelerator.is_main_process:
-            print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
-            save_args(basepath, args, {
-                "global_step": global_step + global_step_offset
-            })
-            accelerator.end_training()
-
-    except KeyboardInterrupt:
-        if accelerator.is_main_process:
-            print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_args(basepath, args, {
-                "global_step": global_step + global_step_offset
-            })
-            accelerator.end_training()
-            quit()
+    else:
+        train_loop(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            model=text_encoder,
+            checkpointer=checkpointer,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loop,
+            sample_frequency=args.sample_frequency,
+            sample_steps=args.sample_steps,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            num_epochs=args.num_train_epochs,
+            on_log=on_log,
+            on_train=on_train,
+            on_after_optimize=on_after_optimize,
+            on_eval=on_eval
+        )
 
 
 if __name__ == "__main__":
-- 
cgit v1.2.3-54-g00ecf
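
Note: the new train_loop lives in training/common.py, which this diff does not include (the patch is limited to train_ti.py). For orientation only, a callback-driven loop consistent with the call site added above could look like the sketch below; the keyword parameters mirror that call, but the function body, defaults, and comments are illustrative assumptions, not the repository's actual implementation.

# Hypothetical sketch of a callback-driven train_loop; the real function in
# training/common.py is not shown in this patch and will differ in detail.
from contextlib import nullcontext

import torch


def train_loop(
    accelerator,
    optimizer,
    lr_scheduler,
    model,
    checkpointer,
    train_dataloader,
    val_dataloader,
    loss_step,
    sample_frequency,
    sample_steps,
    checkpoint_frequency,
    global_step_offset=0,
    gradient_accumulation_steps=1,      # accumulation itself is handled by accelerator.accumulate()
    num_epochs=100,
    on_log=lambda: {},                  # extra key/value pairs merged into each log call
    on_train=nullcontext,               # context-manager factory entered for the training phase
    on_after_optimize=lambda lr: None,  # called once per real optimizer step
    on_eval=nullcontext,                # context-manager factory entered for validation
):
    global_step = 0

    for epoch in range(num_epochs):
        if accelerator.is_main_process and epoch % sample_frequency == 0:
            checkpointer.save_samples(global_step + global_step_offset, sample_steps)

        # Training phase: script-specific behaviour is injected via the callbacks.
        model.train()
        with on_train():
            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(model):
                    # bsz would feed running-average meters in the real loop; omitted here.
                    loss, acc, bsz = loss_step(step, batch)
                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                if accelerator.sync_gradients:
                    on_after_optimize(lr_scheduler.get_last_lr()[0])
                    global_step += 1

                logs = {
                    "train/loss": loss.item(),
                    "train/acc": acc.item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                logs.update(on_log())
                accelerator.log(logs, step=global_step)

        # Validation phase.
        model.eval()
        with torch.inference_mode(), on_eval():
            for step, batch in enumerate(val_dataloader):
                loss, acc, bsz = loss_step(step, batch, True)
                accelerator.log(
                    {"val/loss": loss.item(), "val/acc": acc.item()},
                    step=global_step,
                )

        # Periodic checkpoints, mirroring the cadence of the removed loop.
        if accelerator.is_main_process and (epoch + 1) % checkpoint_frequency == 0:
            checkpointer.checkpoint(global_step + global_step_offset, "training")

The intent of the refactor is visible in the sketch: the generic epoch/step/validation machinery sits behind one function, while train_ti.py contributes only its script-specific behaviour (EMA stepping, embedding decay, extra log fields) through the loss_step callable and the on_* callbacks.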