summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
16 16
17 17
18class AverageMeter: 18class AverageMeter:
19 avg: Any 19 def __init__(self, inv_gamma=1.0, power=2 / 3):
20 20 self.inv_gamma = inv_gamma
21 def __init__(self, name=None): 21 self.power = power
22 self.name = name
23 self.reset() 22 self.reset()
24 23
25 def reset(self): 24 def reset(self):
26 self.sum = self.count = self.avg = 0 25 self.step = 0
26 self.avg = 0
27
28 def get_decay(self):
29 if self.step <= 0:
30 return 1
31
32 return (self.step / self.inv_gamma) ** -self.power
27 33
28 def update(self, val, n=1): 34 def update(self, val, n=1):
29 self.sum += val * n 35 for _ in range(n):
30 self.count += n 36 self.step += n
31 self.avg = self.sum / self.count 37 self.avg += (val - self.avg) * self.get_decay()
32 38
33 39
34class EMAModel(EMAModel_): 40class EMAModel(EMAModel_):