summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-28 16:22:06 +0200
committerVolpeon <git@volpeon.ink>2023-04-28 16:22:06 +0200
commitec762fd3afaa6df0715fa1cbe9e6f088b9276ea0 (patch)
tree04106c3b44ebafd13cfa5f2f9e2c8bf30ab57bc2 /training/util.py
parentFix (diff)
downloadtextual-inversion-diff-ec762fd3afaa6df0715fa1cbe9e6f088b9276ea0.tar.gz
textual-inversion-diff-ec762fd3afaa6df0715fa1cbe9e6f088b9276ea0.tar.bz2
textual-inversion-diff-ec762fd3afaa6df0715fa1cbe9e6f088b9276ea0.zip
Fixed loss/acc logging
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py
index 61f1533..0b6bea9 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import math
3from typing import Iterable, Any 4from typing import Iterable, Any
4from contextlib import contextmanager 5from contextlib import contextmanager
5 6
@@ -23,7 +24,9 @@ class AverageMeter:
23 24
24 def reset(self): 25 def reset(self):
25 self.step = 0 26 self.step = 0
26 self.avg = 0 27 self.min = math.inf
28 self.max = 0.0
29 self.avg = 0.0
27 30
28 def get_decay(self): 31 def get_decay(self):
29 if self.step <= 0: 32 if self.step <= 0:
@@ -35,6 +38,8 @@ class AverageMeter:
35 for _ in range(n): 38 for _ in range(n):
36 self.step += n 39 self.step += n
37 self.avg += (val - self.avg) * self.get_decay() 40 self.avg += (val - self.avg) * self.get_decay()
41 self.min = min(self.min, self.avg)
42 self.max = max(self.max, self.avg)
38 43
39 44
40class EMAModel(EMAModel_): 45class EMAModel(EMAModel_):