From 5b9a3de142e7a645573b4f4a8c1ce9c59746ab08 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 09:25:30 +0100
Subject: Added functional trainer

---
 train_ti.py            | 49 ++++++++++++++++-----------------
 trainer_old/base.py    | 14 +++-------
 training/functional.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 101 insertions(+), 37 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 78c1b5c..97e4e72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -703,17 +703,27 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
-    if args.gradient_checkpointing:
-        unet.train()
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+    )
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if args.gradient_checkpointing:
+            unet.train()
 
     @contextmanager
     def on_train(epoch: int):
@@ -752,16 +762,6 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
     checkpointer = TextualInversionCheckpointer(
         dtype=weight_dtype,
         train_dataloader=train_dataloader,
@@ -803,18 +803,15 @@ def main():
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        train_loop(
-            accelerator=accelerator,
+        trainer(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
+            num_train_epochs=args.num_train_epochs,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            num_epochs=args.num_train_epochs,
+            prior_loss_weight=args.prior_loss_weight,
+            on_prepare=on_prepare,
             on_log=on_log,
             on_train=on_train,
             on_after_optimize=on_after_optimize,
diff --git a/trainer_old/base.py b/trainer_old/base.py
index 1f85e71..5903d96 100644
--- a/trainer_old/base.py
+++ b/trainer_old/base.py
@@ -174,19 +174,13 @@ class TrainingStrategy():
 
     @contextmanager
     def on_train(self, epoch: int):
-        try:
-            self.tokenizer.train()
-            yield
-        finally:
-            pass
+        self.tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval(self):
-        try:
-            self.tokenizer.eval()
-            yield
-        finally:
-            pass
+        self.tokenizer.eval()
+        yield
 
     def on_before_optimize(self, epoch: int):
         ...
diff --git a/training/functional.py b/training/functional.py
index c5b514a..1f2ca6d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,6 +1,7 @@
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Optional
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -376,3 +377,75 @@ def train_loop(
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    prior_loss_weight: float = 0,
+    on_prepare: Callable[[], dict[str, Any]] = const({}),
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        prior_loss_weight,
+        seed,
+    )
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval,
+        on_sample=on_sample,
+        on_checkpoint=on_checkpoint,
+    )
+
+    accelerator.free_memory()
-- 
cgit v1.2.3-70-g09d2