From 37baa3aa254af721728aa33befdc383858cb8ea2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 15 Jan 2023 10:38:49 +0100 Subject: Removed unused code, put training callbacks in dataclass --- training/functional.py | 63 +++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 34 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index e54c9c8..4ca7470 100644 --- a/training/functional.py +++ b/training/functional.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import math from contextlib import _GeneratorContextManager, nullcontext from typing import Callable, Any, Tuple, Union, Optional @@ -14,6 +15,7 @@ from transformers import CLIPTextModel from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler from tqdm.auto import tqdm +from PIL import Image from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings @@ -28,6 +30,18 @@ def const(result=None): return fn +@dataclass +class TrainingCallbacks(): + on_prepare: Callable[[float], None] = const() + on_log: Callable[[], dict[str, Any]] = const({}) + on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) + on_before_optimize: Callable[[int], None] = const() + on_after_optimize: Callable[[float], None] = const() + on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) + on_sample: Callable[[int], None] = const() + on_checkpoint: Callable[[int, str], None] = const() + + def make_grid(images, rows, cols): w, h = images[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) @@ -341,13 +355,7 @@ def train_loop( checkpoint_frequency: int = 50, global_step_offset: int = 0, num_epochs: int = 100, - on_log: Callable[[], dict[str, Any]] = const({}), - on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), - on_before_optimize: Callable[[int], None] = const(), - on_after_optimize: Callable[[float], None] = const(), - on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), - on_sample: Callable[[int], None] = const(), - on_checkpoint: Callable[[int, str], None] = const(), + callbacks: TrainingCallbacks = TrainingCallbacks(), ): num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) num_val_steps_per_epoch = len(val_dataloader) @@ -383,24 +391,24 @@ def train_loop( for epoch in range(num_epochs): if accelerator.is_main_process: if epoch % sample_frequency == 0: - on_sample(global_step + global_step_offset) + callbacks.on_sample(global_step + global_step_offset) if epoch % checkpoint_frequency == 0 and epoch != 0: - on_checkpoint(global_step + global_step_offset, "training") + callbacks.on_checkpoint(global_step + global_step_offset, "training") local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") local_progress_bar.reset() model.train() - with on_train(epoch): + with callbacks.on_train(epoch): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(model): loss, acc, bsz = loss_step(step, batch) accelerator.backward(loss) - on_before_optimize(epoch) + callbacks.on_before_optimize(epoch) optimizer.step() lr_scheduler.step() @@ -411,7 +419,7 @@ def train_loop( # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - on_after_optimize(lr_scheduler.get_last_lr()[0]) + callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0]) local_progress_bar.update(1) global_progress_bar.update(1) @@ -425,7 +433,7 @@ def train_loop( "train/cur_acc": acc.item(), "lr": lr_scheduler.get_last_lr()[0], } - logs.update(on_log()) + logs.update(callbacks.on_log()) accelerator.log(logs, step=global_step) @@ -441,7 +449,7 @@ def train_loop( cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() - with torch.inference_mode(), on_eval(): + with torch.inference_mode(), callbacks.on_eval(): for step, batch in enumerate(val_dataloader): loss, acc, bsz = loss_step(step, batch, True) @@ -477,20 +485,20 @@ def train_loop( if avg_acc_val.avg.item() > max_acc_val: accelerator.print( f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") - on_checkpoint(global_step + global_step_offset, "milestone") + callbacks.on_checkpoint(global_step + global_step_offset, "milestone") max_acc_val = avg_acc_val.avg.item() # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished!") - on_checkpoint(global_step + global_step_offset, "end") - on_sample(global_step + global_step_offset) + callbacks.on_checkpoint(global_step + global_step_offset, "end") + callbacks.on_sample(global_step + global_step_offset) accelerator.end_training() except KeyboardInterrupt: if accelerator.is_main_process: print("Interrupted") - on_checkpoint(global_step + global_step_offset, "end") + callbacks.on_checkpoint(global_step + global_step_offset, "end") accelerator.end_training() @@ -511,14 +519,7 @@ def train( checkpoint_frequency: int = 50, global_step_offset: int = 0, prior_loss_weight: float = 0, - on_prepare: Callable[[], dict[str, Any]] = const({}), - on_log: Callable[[], dict[str, Any]] = const({}), - on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), - on_before_optimize: Callable[[int], None] = const(), - on_after_optimize: Callable[[float], None] = const(), - on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), - on_sample: Callable[[int], None] = const(), - on_checkpoint: Callable[[int, str], None] = const(), + callbacks: TrainingCallbacks = TrainingCallbacks(), ): unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler @@ -530,7 +531,7 @@ def train( model.requires_grad_(False) model.eval() - on_prepare() + callbacks.on_prepare() loss_step_ = partial( loss_step, @@ -557,13 +558,7 @@ def train( checkpoint_frequency=checkpoint_frequency, global_step_offset=global_step_offset, num_epochs=num_train_epochs, - on_log=on_log, - on_train=on_train, - on_before_optimize=on_before_optimize, - on_after_optimize=on_after_optimize, - on_eval=on_eval, - on_sample=on_sample, - on_checkpoint=on_checkpoint, + callbacks=callbacks, ) accelerator.free_memory() -- cgit v1.2.3-70-g09d2