diff options
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py index 8bd8a83..61f1533 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}): | |||
16 | 16 | ||
17 | 17 | ||
18 | class AverageMeter: | 18 | class AverageMeter: |
19 | avg: Any | 19 | def __init__(self, inv_gamma=1.0, power=2 / 3): |
20 | 20 | self.inv_gamma = inv_gamma | |
21 | def __init__(self, name=None): | 21 | self.power = power |
22 | self.name = name | ||
23 | self.reset() | 22 | self.reset() |
24 | 23 | ||
25 | def reset(self): | 24 | def reset(self): |
26 | self.sum = self.count = self.avg = 0 | 25 | self.step = 0 |
26 | self.avg = 0 | ||
27 | |||
28 | def get_decay(self): | ||
29 | if self.step <= 0: | ||
30 | return 1 | ||
31 | |||
32 | return (self.step / self.inv_gamma) ** -self.power | ||
27 | 33 | ||
28 | def update(self, val, n=1): | 34 | def update(self, val, n=1): |
29 | self.sum += val * n | 35 | for _ in range(n): |
30 | self.count += n | 36 | self.step += n |
31 | self.avg = self.sum / self.count | 37 | self.avg += (val - self.avg) * self.get_decay() |
32 | 38 | ||
33 | 39 | ||
34 | class EMAModel(EMAModel_): | 40 | class EMAModel(EMAModel_): |