From 3353ffb64c280a938a0f2513d13b716c1fca8c02 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 17:10:06 +0100 Subject: Cleanup --- train_dreambooth.py | 265 ++++++++++++++++++++++++++-------------------------- 1 file changed, 131 insertions(+), 134 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index d265bcc..589af59 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -5,6 +5,7 @@ import datetime import logging from pathlib import Path from functools import partial +from contextlib import contextmanager, nullcontext import torch import torch.utils.checkpoint @@ -23,7 +24,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import run_model +from training.common import run_model, generate_class_images from training.optimization import get_one_cycle_schedule from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, save_args @@ -216,7 +217,6 @@ def parse_args(): parser.add_argument( "--scale_lr", action="store_true", - default=True, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( @@ -273,7 +273,6 @@ def parse_args(): parser.add_argument( "--use_ema", action="store_true", - default=True, help="Whether to use EMA model." ) parser.add_argument( @@ -294,7 +293,6 @@ def parse_args(): parser.add_argument( "--use_8bit_adam", action="store_true", - default=True, help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument( @@ -486,17 +484,20 @@ class Checkpointer(CheckpointerBase): def save_model(self): print("Saving model...") - unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) + unet = self.accelerator.unwrap_model(self.unet) text_encoder = self.accelerator.unwrap_model(self.text_encoder) - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ) - pipeline.save_pretrained(self.output_dir.joinpath("model")) + ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext() + + with ema_context: + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=self.vae, + unet=unet, + tokenizer=self.tokenizer, + scheduler=self.scheduler, + ) + pipeline.save_pretrained(self.output_dir.joinpath("model")) del unet del text_encoder @@ -507,28 +508,31 @@ class Checkpointer(CheckpointerBase): @torch.no_grad() def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): - unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) + unet = self.accelerator.unwrap_model(self.unet) text_encoder = self.accelerator.unwrap_model(self.text_encoder) - orig_unet_dtype = unet.dtype - orig_text_encoder_dtype = text_encoder.dtype + ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext() - unet.to(dtype=self.weight_dtype) - text_encoder.to(dtype=self.weight_dtype) + with ema_context: + orig_unet_dtype = unet.dtype + orig_text_encoder_dtype = text_encoder.dtype - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) + unet.to(dtype=self.weight_dtype) + text_encoder.to(dtype=self.weight_dtype) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=self.vae, + unet=unet, + tokenizer=self.tokenizer, + scheduler=self.scheduler, + ).to(self.accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) - unet.to(dtype=orig_unet_dtype) - text_encoder.to(dtype=orig_text_encoder_dtype) + unet.to(dtype=orig_unet_dtype) + text_encoder.to(dtype=orig_text_encoder_dtype) del unet del text_encoder @@ -580,6 +584,7 @@ def main(): noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder='scheduler') + ema_unet = None vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) @@ -589,14 +594,12 @@ def main(): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() - ema_unet = None if args.use_ema: ema_unet = EMAModel( - unet, + unet.parameters(), inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay, - device=accelerator.device ) embeddings = patch_managed_embeddings(text_encoder) @@ -748,52 +751,27 @@ def main(): datamodule.prepare_data() datamodule.setup() - if args.num_class_images != 0: - missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] - - if len(missing_data) != 0: - batched_data = [ - missing_data[i:i+args.sample_batch_size] - for i in range(0, len(missing_data), args.sample_batch_size) - ] - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=checkpoint_scheduler, - ).to(accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - with torch.inference_mode(): - for batch in batched_data: - image_name = [item.class_image_path for item in batch] - prompt = [item.cprompt for item in batch] - nprompt = [item.nprompt for item in batch] - - images = pipeline( - prompt=prompt, - negative_prompt=nprompt, - height=args.sample_image_size, - width=args.sample_image_size, - num_inference_steps=args.sample_steps - ).images - - for i, image in enumerate(images): - image.save(image_name[i]) - - del pipeline + train_dataloaders = datamodule.train_dataloaders + val_dataloader = datamodule.val_dataloader - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - train_dataloader = datamodule.train_dataloader() - val_dataloader = datamodule.val_dataloader() + if args.num_class_images != 0: + generate_class_images( + accelerator, + text_encoder, + vae, + unet, + tokenizer, + checkpoint_scheduler, + datamodule.data_train, + args.sample_batch_size, + args.sample_image_size, + args.sample_steps + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) + num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -838,8 +816,12 @@ def main(): # Keep text_encoder and vae in eval mode as we don't train these vae.eval() + if args.use_ema: + ema_unet.to(accelerator.device) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) + num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch @@ -847,11 +829,25 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + @contextmanager def on_train(): - tokenizer.train() + try: + tokenizer.train() + yield + finally: + pass + @contextmanager def on_eval(): - tokenizer.eval() + try: + tokenizer.eval() + + ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext() + + with ema_context: + yield + finally: + pass loop = partial( run_model, @@ -881,7 +877,7 @@ def main(): accelerator, text_encoder, optimizer, - train_dataloader, + train_dataloaders[0], val_dataloader, loop, on_train=tokenizer.train, @@ -962,88 +958,89 @@ def main(): text_encoder.train() elif epoch == args.train_text_encoder_epochs: text_encoder.requires_grad_(False) - on_train() - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(unet): - loss, acc, bsz = loop(step, batch) + with on_train(): + for train_dataloader in train_dataloaders: + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + loss, acc, bsz = loop(step, batch) - accelerator.backward(loss) + accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder and epoch < args.train_text_encoder_epochs - else unet.parameters() - ) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder and epoch < args.train_text_encoder_epochs + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - if not accelerator.optimizer_step_was_skipped: - lr_scheduler.step() - if args.use_ema: - ema_unet.step(unet) - optimizer.zero_grad(set_to_none=True) + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + if args.use_ema: + ema_unet.step(unet.parameters()) + optimizer.zero_grad(set_to_none=True) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) + avg_loss.update(loss.detach_(), bsz) + avg_acc.update(acc.detach_(), bsz) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - local_progress_bar.update(1) - global_progress_bar.update(1) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + local_progress_bar.update(1) + global_progress_bar.update(1) - global_step += 1 + global_step += 1 - logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), - "train/cur_loss": loss.item(), - "train/cur_acc": acc.item(), - "lr": lr_scheduler.get_last_lr()[0] - } - if args.use_ema: - logs["ema_decay"] = 1 - ema_unet.decay + logs = { + "train/loss": avg_loss.avg.item(), + "train/acc": avg_acc.avg.item(), + "train/cur_loss": loss.item(), + "train/cur_acc": acc.item(), + "lr": lr_scheduler.get_last_lr()[0] + } + if args.use_ema: + logs["ema_decay"] = 1 - ema_unet.decay - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=global_step) - local_progress_bar.set_postfix(**logs) + local_progress_bar.set_postfix(**logs) - if global_step >= args.max_train_steps: - break + if global_step >= args.max_train_steps: + break accelerator.wait_for_everyone() unet.eval() text_encoder.eval() - on_eval() cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() with torch.inference_mode(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(step, batch, True) + with on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loop(step, batch, True) - loss = loss.detach_() - acc = acc.detach_() + loss = loss.detach_() + acc = acc.detach_() - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) logs["val/cur_loss"] = cur_loss_val.avg.item() logs["val/cur_acc"] = cur_acc_val.avg.item() -- cgit v1.2.3-54-g00ecf