Diffstat (limited to 'training')
-rw-r--r--  training/common.py | 37
-rw-r--r--  training/util.py   | 16
2 files changed, 26 insertions(+), 27 deletions(-)
diff --git a/training/common.py b/training/common.py
index 8083137..5d1e3f9 100644
--- a/training/common.py
+++ b/training/common.py
@@ -316,30 +316,29 @@ def train_loop(
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
 
-        with torch.inference_mode():
-            with on_eval():
-                for step, batch in enumerate(val_dataloader):
-                    loss, acc, bsz = loss_step(step, batch, True)
+        with torch.inference_mode(), on_eval():
+            for step, batch in enumerate(val_dataloader):
+                loss, acc, bsz = loss_step(step, batch, True)
 
-                    loss = loss.detach_()
-                    acc = acc.detach_()
+                loss = loss.detach_()
+                acc = acc.detach_()
 
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
+                cur_loss_val.update(loss, bsz)
+                cur_acc_val.update(acc, bsz)
 
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
+                avg_loss_val.update(loss, bsz)
+                avg_acc_val.update(acc, bsz)
 
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                local_progress_bar.update(1)
+                global_progress_bar.update(1)
 
-                    logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
-                        "val/cur_loss": loss.item(),
-                        "val/cur_acc": acc.item(),
-                    }
-                    local_progress_bar.set_postfix(**logs)
+                logs = {
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
+                local_progress_bar.set_postfix(**logs)
 
         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()
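Note: the common.py hunk flattens two nested with blocks into a single with statement with comma-separated context managers and re-indents the loop body one level; both managers are still entered left to right and exited in reverse order, so behaviour is unchanged. A minimal standalone sketch of the equivalence follows (on_eval here is a hypothetical stand-in for the repo's context manager, not its real implementation):

    from contextlib import contextmanager

    import torch


    @contextmanager
    def on_eval():
        # Hypothetical stand-in for the repo's on_eval() context manager.
        print("entering eval mode")
        try:
            yield
        finally:
            print("leaving eval mode")


    # Old shape: two nested blocks, one extra indentation level for the body.
    with torch.inference_mode():
        with on_eval():
            pass

    # New shape: same enter/exit order, one less level of nesting.
    with torch.inference_mode(), on_eval():
        pass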
diff --git a/training/util.py b/training/util.py
index 1008021..781cf04 100644
--- a/training/util.py
+++ b/training/util.py
@@ -134,11 +134,11 @@ class EMAModel:
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
-        update_after_step=0,
-        inv_gamma=1.0,
-        power=2 / 3,
-        min_value=0.0,
-        max_value=0.9999,
+        update_after_step: int = 0,
+        inv_gamma: float = 1.0,
+        power: float = 2 / 3,
+        min_value: float = 0.0,
+        max_value: float = 0.9999,
     ):
         """
         @crowsonkb's notes on EMA Warmup:
@@ -165,7 +165,7 @@ class EMAModel:
         self.decay = 0.0
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
+    def get_decay(self, optimization_step: int):
         """
         Compute the decay factor for the exponential moving average.
         """
@@ -276,5 +276,5 @@ class EMAModel:
             self.copy_to(parameters)
             yield
         finally:
-            for s_param, param in zip(original_params, parameters):
-                param.data.copy_(s_param.data)
+            for o_param, param in zip(original_params, parameters):
+                param.data.copy_(o_param.data)
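Note: the final hunk renames the loop variable from s_param to o_param, making it clearer that the values restored in the finally block are the saved original parameters rather than the EMA shadow ones. The pattern it belongs to (swap in EMA weights for the duration of a with block, then restore the originals) looks roughly like this standalone sketch; the helper name and signature are illustrative, not the repo's actual method:

    from contextlib import contextmanager
    from typing import Iterable

    import torch


    @contextmanager
    def swap_in_ema_weights(ema, parameters: Iterable[torch.nn.Parameter]):
        # Illustrative helper, not copied from training/util.py.
        parameters = list(parameters)
        original_params = [p.clone() for p in parameters]  # keep the live weights
        try:
            ema.copy_to(parameters)  # overwrite with the EMA weights
            yield
        finally:
            # Restore the saved originals (o_param), even if the block raised.
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param.data)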