Diffstat (limited to 'training')

-rw-r--r--   training/common.py | 37
-rw-r--r--   training/util.py   | 16

2 files changed, 26 insertions, 27 deletions
diff --git a/training/common.py b/training/common.py
index 8083137..5d1e3f9 100644
--- a/training/common.py
+++ b/training/common.py
@@ -316,30 +316,29 @@ def train_loop(
     cur_loss_val = AverageMeter()
     cur_acc_val = AverageMeter()
 
-    with torch.inference_mode():
-        with on_eval():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loss_step(step, batch, True)
+    with torch.inference_mode(), on_eval():
+        for step, batch in enumerate(val_dataloader):
+            loss, acc, bsz = loss_step(step, batch, True)
 
-                loss = loss.detach_()
-                acc = acc.detach_()
+            loss = loss.detach_()
+            acc = acc.detach_()
 
-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
+            cur_loss_val.update(loss, bsz)
+            cur_acc_val.update(acc, bsz)
 
-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
+            avg_loss_val.update(loss, bsz)
+            avg_acc_val.update(acc, bsz)
 
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+            local_progress_bar.update(1)
+            global_progress_bar.update(1)
 
-                logs = {
-                    "val/loss": avg_loss_val.avg.item(),
-                    "val/acc": avg_acc_val.avg.item(),
-                    "val/cur_loss": loss.item(),
-                    "val/cur_acc": acc.item(),
-                }
-                local_progress_bar.set_postfix(**logs)
+            logs = {
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
+                "val/cur_loss": loss.item(),
+                "val/cur_acc": acc.item(),
+            }
+            local_progress_bar.set_postfix(**logs)
 
     logs["val/cur_loss"] = cur_loss_val.avg.item()
     logs["val/cur_acc"] = cur_acc_val.avg.item()
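The only behavioral change in this hunk is folding the nested context managers into a single with statement; the rest of the diff is the resulting dedent. A minimal sketch of why the two forms are equivalent, using a hypothetical timed() manager in place of this repo's on_eval():

import contextlib

import torch


@contextlib.contextmanager
def timed(label: str):
    # hypothetical stand-in for on_eval(); any context manager behaves the same
    print(f"enter {label}")
    try:
        yield
    finally:
        print(f"exit {label}")


# old shape: nested with-blocks
with torch.inference_mode():
    with timed("eval"):
        pass

# new shape: one with-statement; managers are entered left to right and
# exited right to left, exactly like the nested version
with torch.inference_mode(), timed("eval"):
    pass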
diff --git a/training/util.py b/training/util.py
index 1008021..781cf04 100644
--- a/training/util.py
+++ b/training/util.py
@@ -134,11 +134,11 @@ class EMAModel:
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
-        update_after_step=0,
-        inv_gamma=1.0,
-        power=2 / 3,
-        min_value=0.0,
-        max_value=0.9999,
+        update_after_step: int = 0,
+        inv_gamma: float = 1.0,
+        power: float = 2 / 3,
+        min_value: float = 0.0,
+        max_value: float = 0.9999,
     ):
         """
         @crowsonkb's notes on EMA Warmup:
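For orientation, a hedged construction example for the signature above; the import path and the Linear model are assumptions, only the keyword parameters and their defaults come from the hunk:

import torch

from training.util import EMAModel  # assumed import path for this repo

model = torch.nn.Linear(4, 4)  # placeholder model

# defaults written out explicitly; they mirror the annotated signature above
ema = EMAModel(
    model.parameters(),
    update_after_step=0,
    inv_gamma=1.0,
    power=2 / 3,
    min_value=0.0,
    max_value=0.9999,
)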
@@ -165,7 +165,7 @@ class EMAModel:
         self.decay = 0.0
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
+    def get_decay(self, optimization_step: int):
         """
         Compute the decay factor for the exponential moving average.
         """
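The body of get_decay is outside this hunk. The docstring credits @crowsonkb's EMA warmup, which is usually written as decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value]; a standalone sketch under that assumption (it may not match this file's exact code):

def ema_warmup_decay(
    optimization_step: int,
    update_after_step: int = 0,
    inv_gamma: float = 1.0,
    power: float = 2 / 3,
    min_value: float = 0.0,
    max_value: float = 0.9999,
) -> float:
    """Assumed @crowsonkb-style EMA warmup schedule; not copied from training/util.py."""
    step = max(0, optimization_step - update_after_step)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

Under this assumed schedule with the defaults above (inv_gamma=1.0, power=2/3), the decay crosses 0.999 around step 1000**1.5, roughly 31.6k, and would only approach 0.9999 near step 1e6, so max_value=0.9999 acts as the long-run cap.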
@@ -276,5 +276,5 @@ class EMAModel:
             self.copy_to(parameters)
             yield
         finally:
-            for s_param, param in zip(original_params, parameters):
-                param.data.copy_(s_param.data)
+            for o_param, param in zip(original_params, parameters):
+                param.data.copy_(o_param.data)
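The rename is purely cosmetic: original_params holds the weights saved before copy_to() loads the EMA averages, and the finally block restores them, so o_param is the clearer name for a saved original. A minimal sketch of the same swap-and-restore pattern, with names that are illustrative rather than this class's API:

import contextlib
from typing import Iterable

import torch


@contextlib.contextmanager
def swap_in(ema_params: Iterable[torch.Tensor], parameters: Iterable[torch.nn.Parameter]):
    """Temporarily load EMA weights into a model, restoring the originals on exit."""
    parameters = list(parameters)
    original_params = [p.detach().clone() for p in parameters]
    try:
        for e_param, param in zip(ema_params, parameters):
            param.data.copy_(e_param)
        yield
    finally:
        for o_param, param in zip(original_params, parameters):
            param.data.copy_(o_param)

Inside the with-block the model evaluates with the EMA weights; on exit, even after an exception, the original training weights are copied back in place.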