From a2c240c8c55dfe930657f66372975d6f26feb168 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 10:02:30 +0100 Subject: TI: Prepare UNet with Accelerate as well --- training/common.py | 37 ++++++++++++++++++------------------- training/util.py | 16 ++++++++-------- 2 files changed, 26 insertions(+), 27 deletions(-) (limited to 'training') diff --git a/training/common.py b/training/common.py index 8083137..5d1e3f9 100644 --- a/training/common.py +++ b/training/common.py @@ -316,30 +316,29 @@ def train_loop( cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() - with torch.inference_mode(): - with on_eval(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loss_step(step, batch, True) + with torch.inference_mode(), on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loss_step(step, batch, True) - loss = loss.detach_() - acc = acc.detach_() + loss = loss.detach_() + acc = acc.detach_() - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) logs["val/cur_loss"] = cur_loss_val.avg.item() logs["val/cur_acc"] = cur_acc_val.avg.item() diff --git a/training/util.py b/training/util.py index 1008021..781cf04 100644 --- a/training/util.py +++ b/training/util.py @@ -134,11 +134,11 @@ class EMAModel: def __init__( self, parameters: Iterable[torch.nn.Parameter], - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, + update_after_step: int = 0, + inv_gamma: float = 1.0, + power: float = 2 / 3, + min_value: float = 0.0, + max_value: float = 0.9999, ): """ @crowsonkb's notes on EMA Warmup: @@ -165,7 +165,7 @@ class EMAModel: self.decay = 0.0 self.optimization_step = 0 - def get_decay(self, optimization_step): + def get_decay(self, optimization_step: int): """ Compute the decay factor for the exponential moving average. """ @@ -276,5 +276,5 @@ class EMAModel: self.copy_to(parameters) yield finally: - for s_param, param in zip(original_params, parameters): - param.data.copy_(s_param.data) + for o_param, param in zip(original_params, parameters): + param.data.copy_(o_param.data) -- cgit v1.2.3-54-g00ecf