From 672a59abeaa60dc5ef78a33bd9b58e391b922016 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 11:14:24 +0100 Subject: Use context manager for EMA, on_train/eval hooks --- train_ti.py | 120 ++++++++++++++++++++++++++++++------------------------- training/lr.py | 51 ++++++++++++----------- training/util.py | 2 +- 3 files changed, 92 insertions(+), 81 deletions(-) diff --git a/train_ti.py b/train_ti.py index aa2bf02..f622299 100644 --- a/train_ti.py +++ b/train_ti.py @@ -2,10 +2,9 @@ import argparse import math import datetime import logging -import copy -from pathlib import Path from functools import partial -from contextlib import nullcontext +from pathlib import Path +from contextlib import contextmanager, nullcontext import torch import torch.utils.checkpoint @@ -849,11 +848,24 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + @contextmanager def on_train(): - tokenizer.train() + try: + tokenizer.train() + yield + finally: + tokenizer.eval() + @contextmanager def on_eval(): - tokenizer.eval() + try: + ema_context = ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext() + + with ema_context: + yield + finally: + pass loop = partial( run_model, @@ -961,80 +973,80 @@ def main(): local_progress_bar.reset() text_encoder.train() - on_train() - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - loss, acc, bsz = loop(step, batch) + with on_train(): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + loss, acc, bsz = loop(step, batch) - accelerator.backward(loss) + accelerator.backward(loss) - optimizer.step() - if not accelerator.optimizer_step_was_skipped: - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) + avg_loss.update(loss.detach_(), bsz) + avg_acc.update(acc.detach_(), bsz) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - global_step += 1 + global_step += 1 - logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), - "train/cur_loss": loss.item(), - "train/cur_acc": acc.item(), - "lr": lr_scheduler.get_last_lr()[0], - } - if args.use_ema: - logs["ema_decay"] = ema_embeddings.decay + logs = { + "train/loss": avg_loss.avg.item(), + "train/acc": avg_acc.avg.item(), + "train/cur_loss": loss.item(), + "train/cur_acc": acc.item(), + "lr": lr_scheduler.get_last_lr()[0], + } + if args.use_ema: + logs["ema_decay"] = ema_embeddings.decay - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=global_step) - local_progress_bar.set_postfix(**logs) + local_progress_bar.set_postfix(**logs) - if global_step >= args.max_train_steps: - break + if global_step >= args.max_train_steps: + break accelerator.wait_for_everyone() text_encoder.eval() - on_eval() cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() with torch.inference_mode(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(step, batch, True) + with on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loop(step, batch, True) - loss = loss.detach_() - acc = acc.detach_() + loss = loss.detach_() + acc = acc.detach_() - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) logs["val/cur_loss"] = cur_loss_val.avg.item() logs["val/cur_acc"] = cur_acc_val.avg.item() diff --git a/training/lr.py b/training/lr.py index c765150..68e0f72 100644 --- a/training/lr.py +++ b/training/lr.py @@ -1,5 +1,5 @@ import math -import copy +from contextlib import _GeneratorContextManager, nullcontext from typing import Callable, Any, Tuple, Union from functools import partial @@ -25,9 +25,9 @@ class LRFinder(): train_dataloader, val_dataloader, loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], - on_train: Callable[[], None] = noop, + on_train: Callable[[], _GeneratorContextManager] = nullcontext, on_clip: Callable[[], None] = noop, - on_eval: Callable[[], None] = noop + on_eval: Callable[[], _GeneratorContextManager] = nullcontext ): self.accelerator = accelerator self.model = model @@ -51,7 +51,6 @@ class LRFinder(): num_train_batches: int = 1, num_val_batches: int = math.inf, smooth_f: float = 0.05, - diverge_th: int = 5, ): best_loss = None best_acc = None @@ -84,40 +83,40 @@ class LRFinder(): avg_acc = AverageMeter() self.model.train() - self.on_train() - for step, batch in enumerate(self.train_dataloader): - if step >= num_train_batches: - break + with self.on_train(): + for step, batch in enumerate(self.train_dataloader): + if step >= num_train_batches: + break - with self.accelerator.accumulate(self.model): - loss, acc, bsz = self.loss_fn(step, batch) + with self.accelerator.accumulate(self.model): + loss, acc, bsz = self.loss_fn(step, batch) - self.accelerator.backward(loss) + self.accelerator.backward(loss) - if self.accelerator.sync_gradients: - self.on_clip() + if self.accelerator.sync_gradients: + self.on_clip() - self.optimizer.step() - lr_scheduler.step() - self.optimizer.zero_grad(set_to_none=True) + self.optimizer.step() + lr_scheduler.step() + self.optimizer.zero_grad(set_to_none=True) - if self.accelerator.sync_gradients: - progress_bar.update(1) + if self.accelerator.sync_gradients: + progress_bar.update(1) self.model.eval() - self.on_eval() with torch.inference_mode(): - for step, batch in enumerate(self.val_dataloader): - if step >= num_val_batches: - break + with self.on_eval(): + for step, batch in enumerate(self.val_dataloader): + if step >= num_val_batches: + break - loss, acc, bsz = self.loss_fn(step, batch, True) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) + loss, acc, bsz = self.loss_fn(step, batch, True) + avg_loss.update(loss.detach_(), bsz) + avg_acc.update(acc.detach_(), bsz) - progress_bar.update(1) + progress_bar.update(1) loss = avg_loss.avg.item() acc = avg_acc.avg.item() diff --git a/training/util.py b/training/util.py index 6f1e85a..bed7111 100644 --- a/training/util.py +++ b/training/util.py @@ -262,7 +262,7 @@ class EMAModel: raise ValueError("collected_params and shadow_params must have the same length") @contextmanager - def apply_temporary(self, parameters): + def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): try: parameters = list(parameters) original_params = [p.clone() for p in parameters] -- cgit v1.2.3-70-g09d2