diff options
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py index 8bd8a83..61f1533 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}): | |||
| 16 | 16 | ||
| 17 | 17 | ||
| 18 | class AverageMeter: | 18 | class AverageMeter: |
| 19 | avg: Any | 19 | def __init__(self, inv_gamma=1.0, power=2 / 3): |
| 20 | 20 | self.inv_gamma = inv_gamma | |
| 21 | def __init__(self, name=None): | 21 | self.power = power |
| 22 | self.name = name | ||
| 23 | self.reset() | 22 | self.reset() |
| 24 | 23 | ||
| 25 | def reset(self): | 24 | def reset(self): |
| 26 | self.sum = self.count = self.avg = 0 | 25 | self.step = 0 |
| 26 | self.avg = 0 | ||
| 27 | |||
| 28 | def get_decay(self): | ||
| 29 | if self.step <= 0: | ||
| 30 | return 1 | ||
| 31 | |||
| 32 | return (self.step / self.inv_gamma) ** -self.power | ||
| 27 | 33 | ||
| 28 | def update(self, val, n=1): | 34 | def update(self, val, n=1): |
| 29 | self.sum += val * n | 35 | for _ in range(n): |
| 30 | self.count += n | 36 | self.step += n |
| 31 | self.avg = self.sum / self.count | 37 | self.avg += (val - self.avg) * self.get_decay() |
| 32 | 38 | ||
| 33 | 39 | ||
| 34 | class EMAModel(EMAModel_): | 40 | class EMAModel(EMAModel_): |
