From 5b9a3de142e7a645573b4f4a8c1ce9c59746ab08 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 09:25:30 +0100
Subject: Added functional trainer

---
 training/functional.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 74 insertions(+), 1 deletion(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index c5b514a..1f2ca6d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,6 +1,7 @@
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Optional
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -376,3 +377,75 @@ def train_loop(
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
     accelerator.end_training()
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    prior_loss_weight: float = 0,
+    on_prepare: Callable[[], dict[str, Any]] = const({}),
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        prior_loss_weight,
+        seed,
+    )
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval,
+        on_sample=on_sample,
+        on_checkpoint=on_checkpoint,
+    )
+
+    accelerator.free_memory()
-- 
cgit v1.2.3-70-g09d2
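
Usage note (not part of the patch): the new functional `train` entry point expects already-constructed models, dataloaders, optimizer, and scheduler, freezes the models, builds a partial `loss_step`, and delegates the epoch loop to `train_loop`. The sketch below is a hypothetical illustration of the call shape only; the base checkpoint id, optimizer choice, and empty placeholder dataloaders are assumptions for illustration and are not defined by this commit.

# Hypothetical caller sketch for training.functional.train (assumptions noted below).
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from torch.utils.data import DataLoader
from transformers import CLIPTextModel

from training.functional import train

accelerator = Accelerator()

# Assumed base checkpoint; any Stable Diffusion checkpoint layout with
# unet/text_encoder/vae/scheduler subfolders would work the same way.
model_id = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Placeholder dataloaders; the real ones come from the project's dataset code,
# which this patch does not touch.
train_dataloader = DataLoader([], batch_size=1)
val_dataloader = DataLoader([], batch_size=1)

# Assumed optimizer/scheduler choices, purely for illustration of the call shape.
optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

train(
    accelerator=accelerator,
    unet=unet,
    text_encoder=text_encoder,
    vae=vae,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    dtype=torch.float32,
    seed=42,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_train_epochs=100,
    sample_frequency=20,
    checkpoint_frequency=50,
)

In practice the on_prepare / on_train / on_checkpoint callbacks would be supplied by the specific training script (e.g. to unfreeze the parameters being trained and to save checkpoints), since `train` itself freezes unet, text_encoder, and vae before handing control to `train_loop`.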