From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 13:49:35 +0100
Subject: Removed PromptProcessor, modularized training loop

---
 training/common.py | 205 ++++++++++++++++++++++++++++++++++++++++++++++++++---
 training/util.py   |  13 +++-
 2 files changed, 208 insertions(+), 10 deletions(-)

(limited to 'training')

diff --git a/training/common.py b/training/common.py
index 90cf910..842ac07 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,14 +1,30 @@
 import math
+from contextlib import _GeneratorContextManager, nullcontext
+from typing import Callable, Any, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+from torch.utils.data import DataLoader
 
+from accelerate import Accelerator
+from transformers import CLIPTokenizer, CLIPTextModel
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from tqdm.auto import tqdm
 
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.util import get_extended_embeddings
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase
+
+
+def noop(*args, **kwards):
+    pass
+
+
+def noop_on_log():
+    return {}
 
 
 def get_scheduler(
@@ -22,10 +38,11 @@
     cycles: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    max_train_steps: int,
+    num_train_epochs: int,
     num_update_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
+    num_train_steps = num_train_epochs * num_update_steps_per_epoch
     warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
 
     if id == "one_cycle":
@@ -33,7 +50,7 @@
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_train_steps * gradient_accumulation_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -42,12 +59,12 @@
         )
     elif id == "cosine_with_restarts":
         cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=warmup_steps,
-            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_train_steps * gradient_accumulation_steps,
             num_cycles=cycles,
         )
     else:
@@ -55,7 +72,7 @@
             id,
             optimizer=optimizer,
             num_warmup_steps=warmup_steps,
-            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_train_steps * gradient_accumulation_steps,
         )
 
     return lr_scheduler
@@ -117,12 +134,12 @@ def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
-    prompt_processor,
+    text_encoder: CLIPTextModel,
     num_class_images: int,
     prior_loss_weight: float,
     seed: int,
     step: int,
-    batch,
+    batch: dict[str, Any],
    eval: bool = False
 ):
     # Convert images to latent space
@@ -149,7 +166,8 @@
     noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(
+    encoder_hidden_states = get_extended_embeddings(
+        text_encoder,
         batch["input_ids"],
         batch["attention_mask"]
     )
@@ -185,3 +203,172 @@
         acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
+
+
+def train_loop(
+    accelerator: Accelerator,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    model: torch.nn.Module,
+    checkpointer: CheckpointerBase,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
+    sample_frequency: int = 10,
+    sample_steps: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    gradient_accumulation_steps: int = 1,
+    num_epochs: int = 100,
+    on_log: Callable[[], dict[str, Any]] = noop_on_log,
+    on_train: Callable[[], _GeneratorContextManager] = nullcontext,
+    on_before_optimize: Callable[[], None] = noop,
+    on_after_optimize: Callable[[float], None] = noop,
+    on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+):
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    num_train_steps = num_epochs * num_update_steps_per_epoch
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+    num_val_steps = num_val_steps_per_epoch * num_epochs
+
+    global_step = 0
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
+
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
+
+    global_progress_bar = tqdm(
+        range(num_train_steps + num_val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    global_progress_bar.set_description("Total progress")
+
+    try:
+        for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % sample_frequency == 0:
+                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+
+                if epoch % checkpoint_frequency == 0 and epoch != 0:
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            model.train()
+
+            with on_train():
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(model):
+                        loss, acc, bsz = loss_step(step, batch)
+
+                        accelerator.backward(loss)
+
+                        on_before_optimize()
+
+                        optimizer.step()
+                        lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
+
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    logs.update(on_log())
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
+
+                    if global_step >= num_train_steps:
+                        break
+
+            accelerator.wait_for_everyone()
+
+            model.eval()
+
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
+            with torch.inference_mode():
+                with on_eval():
+                    for step, batch in enumerate(val_dataloader):
+                        loss, acc, bsz = loss_step(step, batch, True)
+
+                        loss = loss.detach_()
+                        acc = acc.detach_()
+
+                        cur_loss_val.update(loss, bsz)
+                        cur_acc_val.update(acc, bsz)
+
+                        avg_loss_val.update(loss, bsz)
+                        avg_acc_val.update(acc, bsz)
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        logs = {
+                            "val/loss": avg_loss_val.avg.item(),
+                            "val/acc": avg_acc_val.avg.item(),
+                            "val/cur_loss": loss.item(),
+                            "val/cur_acc": acc.item(),
+                        }
+                        local_progress_bar.set_postfix(**logs)
+
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc_val = avg_acc_val.avg.item()
+
+        # Create the pipeline using using the trained modules and save it.
+        if accelerator.is_main_process:
+            print("Finished!")
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            accelerator.end_training()
+
+    except KeyboardInterrupt:
+        if accelerator.is_main_process:
+            print("Interrupted")
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            accelerator.end_training()
+        quit()
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
@@ -55,8 +55,19 @@ class CheckpointerBase:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
+    @torch.no_grad()
+    def checkpoint(self, step: int, postfix: str):
+        pass
+
     @torch.inference_mode()
-    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+    def save_samples(
+        self,
+        pipeline,
+        step: int,
+        num_inference_steps: int,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0
+    ):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         train_data = self.datamodule.train_dataloader
-- 
cgit v1.2.3-54-g00ecf
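
Usage reference (not part of the commit): a minimal sketch of how a training script could wire up the refactored train_loop and its hooks. The args.* values, the checkpointer instance and the already prepared models, optimizer, scheduler and dataloaders are assumed to come from the caller's setup and are hypothetical here; only the loss_step and train_loop signatures are taken from the patch above.

    from contextlib import nullcontext
    from functools import partial

    from training.common import loss_step, train_loop

    # Bind the fixed arguments of loss_step so that train_loop only needs to
    # supply (step, batch) during training and (step, batch, True) during eval.
    bound_loss_step = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
        args.num_class_images,    # hypothetical CLI arguments
        args.prior_loss_weight,
        args.seed,
    )

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model=text_encoder,         # whichever module accelerator.accumulate() should wrap
        checkpointer=checkpointer,  # CheckpointerBase subclass providing checkpoint()/save_samples()
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_step=bound_loss_step,
        num_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        global_step_offset=global_step_offset,
        on_log=lambda: {},          # extra key/value pairs merged into each logged metrics dict
        on_train=nullcontext,       # called once per epoch; must return a context manager around the training phase
        on_eval=nullcontext,        # called once per epoch; must return a context manager around the validation phase
    )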