From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 18:59:26 +0100
Subject: More modularization

---
 train_dreambooth.py | 272 +++++++++++++---------------------------------------
 1 file changed, 65 insertions(+), 207 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index fbbe6c2..c892ebf 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,6 +1,5 @@
 import argparse
 import itertools
-import math
 import datetime
 import logging
 from pathlib import Path
@@ -16,16 +15,15 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
-from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify

 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images, get_scheduler
+from training.common import loss_step, train_loop, generate_class_images, get_scheduler
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer

@@ -292,7 +290,7 @@ def parse_args():
     parser.add_argument(
         "--lr_min_lr",
         type=float,
-        default=None,
+        default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
@@ -787,14 +785,6 @@ def main():
             args.sample_steps
         )

-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -802,15 +792,14 @@ def main():
             args.lr_scheduler,
             optimizer=optimizer,
             min_lr=args.lr_min_lr,
-            lr=args.learning_rate,
             warmup_func=args.lr_warmup_func,
             annealing_func=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
-            max_train_steps=args.max_train_steps,
-            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            num_training_steps_per_epoch=len(train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )

@@ -827,19 +816,16 @@ def main():
     if args.use_ema:
         ema_unet.to(accelerator.device)

-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
-    num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    val_steps = num_val_steps_per_epoch * num_epochs
-
     @contextmanager
-    def on_train():
+    def on_train(epoch: int):
         try:
             tokenizer.train()
+
+            if epoch < args.train_text_encoder_epochs:
+                text_encoder.train()
+            elif epoch == args.train_text_encoder_epochs:
+                text_encoder.requires_grad_(False)
+
             yield
         finally:
             pass
@@ -848,6 +834,7 @@ def main():
     def on_eval():
         try:
             tokenizer.eval()
+            text_encoder.eval()

             ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()

@@ -856,7 +843,7 @@ def main():
         finally:
             pass

-    def on_before_optimize():
+    def on_before_optimize(epoch: int):
         if accelerator.sync_gradients:
             params_to_clip = [unet.parameters()]
             if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
@@ -866,9 +853,17 @@ def main():
     @torch.no_grad()
     def on_after_optimize(lr: float):
         if not args.train_text_encoder:
-            text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+            text_encoder.text_model.embeddings.normalize(
+                args.decay_target,
+                min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
+            )
+
+    def on_log():
+        if args.use_ema:
+            return {"ema_decay": ema_unet.decay}
+        return {}

-    loop = partial(
+    loss_step_ = partial(
         loss_step,
         vae,
         noise_scheduler,
@@ -879,8 +874,25 @@ def main():
         args.seed,
     )

-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
+    checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        ema_unet=ema_unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
+        output_dir=basepath,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
@@ -898,9 +910,9 @@ def main():
             optimizer,
             train_dataloader,
             val_dataloader,
-            loop,
-            on_train=tokenizer.train,
-            on_eval=tokenizer.eval,
+            loss_step_,
+            on_train=on_train,
+            on_eval=on_eval,
             on_before_optimize=on_before_optimize,
             on_after_optimize=on_after_optimize,
         )
@@ -909,182 +921,28 @@ def main():
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()

-        quit()
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-
-    global_step = 0
-
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
+        return

-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    max_acc_val = 0.0
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        datamodule=datamodule,
+    train_loop(
         accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        ema_unet=ema_unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=checkpoint_scheduler,
-        output_dir=basepath,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
-
-    global_progress_bar = tqdm(
-        range(args.max_train_steps + val_steps),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=unet,
+        checkpointer=checkpointer,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=args.sample_frequency,
+        sample_steps=args.sample_steps,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        num_epochs=args.num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval
     )
-    global_progress_bar.set_description("Total progress")
-
-    try:
-        for epoch in range(num_epochs):
-            if accelerator.is_main_process:
-                if epoch % args.sample_frequency == 0:
-                    checkpointer.save_samples(global_step, args.sample_steps)
-
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            unet.train()
-            if epoch < args.train_text_encoder_epochs:
-                text_encoder.train()
-            elif epoch == args.train_text_encoder_epochs:
-                text_encoder.requires_grad_(False)
-
-            with on_train():
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        on_before_optimize()
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                            if args.use_ema:
-                                ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0]
-                    }
-                    if args.use_ema:
-                        logs["ema_decay"] = 1 - ema_unet.decay
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
-
-            accelerator.wait_for_everyone()
-
-            unet.eval()
-            text_encoder.eval()
-
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loop(step, batch, True)
-
-                        loss = loss.detach_()
-                        acc = acc.detach_()
-
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
-
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
-                            "val/cur_loss": loss.item(),
-                            "val/cur_acc": acc.item(),
-                        }
-                        local_progress_bar.set_postfix(**logs)
-
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
-
-            accelerator.log(logs, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > max_acc_val:
-                    accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    max_acc_val = avg_acc_val.avg.item()
-
-        # Create the pipeline using using the trained modules and save it.
-        if accelerator.is_main_process:
-            print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.save_samples(global_step, args.sample_steps)
-            checkpointer.save_model()
-            accelerator.end_training()
-
-    except KeyboardInterrupt:
-        if accelerator.is_main_process:
-            print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.save_model()
-            accelerator.end_training()
-        quit()


 if __name__ == "__main__":
-- 
cgit v1.2.3-54-g00ecf
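
Note: the train_loop helper this commit starts calling lives in training/common.py,
which is outside this patch. The following is a minimal sketch of the contract
implied by the call site above; the loop body is simplified, the defaults are
invented for illustration, and checkpoint_frequency handling is omitted.

    from contextlib import nullcontext
    from typing import Callable

    import torch


    def train_loop(
        accelerator,
        optimizer,
        lr_scheduler,
        model,
        checkpointer,
        train_dataloader,
        val_dataloader,
        loss_step: Callable,
        sample_frequency: int,
        sample_steps: int,
        checkpoint_frequency: int,
        global_step_offset: int = 0,
        gradient_accumulation_steps: int = 1,
        num_epochs: int = 100,
        on_log: Callable = lambda: {},
        on_train: Callable = nullcontext,
        on_before_optimize: Callable = lambda epoch: None,
        on_after_optimize: Callable = lambda lr: None,
        on_eval: Callable = nullcontext,
    ):
        global_step = global_step_offset
        for epoch in range(num_epochs):
            # Periodic sampling, as the deleted in-line loop did per epoch.
            if epoch % sample_frequency == 0:
                checkpointer.save_samples(global_step, sample_steps)

            model.train()
            with on_train(epoch):
                for step, batch in enumerate(train_dataloader):
                    with accelerator.accumulate(model):
                        loss, acc, bsz = loss_step(step, batch)
                        accelerator.backward(loss)
                        on_before_optimize(epoch)
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                    # True once per accumulated optimizer update.
                    if accelerator.sync_gradients:
                        lr = lr_scheduler.get_last_lr()[0]
                        on_after_optimize(lr)
                        global_step += 1
                        accelerator.log({"lr": lr, **on_log()}, step=global_step)

            model.eval()
            with torch.inference_mode(), on_eval():
                for step, batch in enumerate(val_dataloader):
                    loss, acc, bsz = loss_step(step, batch, True)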
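The rewritten on_after_optimize replaces the fixed min(1.0, 100 * lr) factor with
a decay weight that ramps linearly with the current learning rate. Pulled out as a
standalone function for clarity (the function name and sample values below are
illustrative, not from the commit):

    def decay_weight(lr: float, decay_factor: float, decay_start: float, learning_rate: float) -> float:
        # Linear ramp from 0 at lr == decay_start up to decay_factor at
        # lr == learning_rate, clamped to [0, 1]; mirrors the expression
        # passed to text_encoder.text_model.embeddings.normalize().
        return min(1.0, max(0.0, decay_factor * ((lr - decay_start) / (learning_rate - decay_start))))


    # Example with invented values: halfway through annealing, the embeddings
    # are pulled halfway toward args.decay_target.
    assert abs(decay_weight(lr=5e-5, decay_factor=1.0, decay_start=0.0, learning_rate=1e-4) - 0.5) < 1e-12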
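get_scheduler now receives train_epochs and num_training_steps_per_epoch instead
of precomputed step counts, so the step arithmetic deleted from main() presumably
moved behind that interface. A sketch of that arithmetic, mirroring the removed
lines; this is a guess at the training/common.py internals, which this patch does
not show:

    import math


    def scheduler_step_counts(train_epochs: int, warmup_epochs: int,
                              num_training_steps_per_epoch: int,
                              gradient_accumulation_steps: int) -> tuple[int, int]:
        # One optimizer update happens per `gradient_accumulation_steps` batches.
        num_update_steps_per_epoch = math.ceil(
            num_training_steps_per_epoch / gradient_accumulation_steps
        )
        total_steps = train_epochs * num_update_steps_per_epoch
        warmup_steps = warmup_epochs * num_update_steps_per_epoch
        return total_steps, warmup_steps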