diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/lr.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py index a3144ba..c8dc040 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -24,7 +24,7 @@ class LRFinder(): | |||
24 | optimizer, | 24 | optimizer, |
25 | train_dataloader, | 25 | train_dataloader, |
26 | val_dataloader, | 26 | val_dataloader, |
27 | loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]], | 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], None] = noop, |
29 | on_eval: Callable[[], None] = noop | 29 | on_eval: Callable[[], None] = noop |
30 | ): | 30 | ): |
@@ -89,7 +89,7 @@ class LRFinder(): | |||
89 | break | 89 | break |
90 | 90 | ||
91 | with self.accelerator.accumulate(self.model): | 91 | with self.accelerator.accumulate(self.model): |
92 | loss, acc, bsz = self.loss_fn(batch) | 92 | loss, acc, bsz = self.loss_fn(step, batch) |
93 | 93 | ||
94 | self.accelerator.backward(loss) | 94 | self.accelerator.backward(loss) |
95 | 95 | ||
@@ -108,7 +108,7 @@ class LRFinder(): | |||
108 | if step >= num_val_batches: | 108 | if step >= num_val_batches: |
109 | break | 109 | break |
110 | 110 | ||
111 | loss, acc, bsz = self.loss_fn(batch, True) | 111 | loss, acc, bsz = self.loss_fn(step, batch, True) |
112 | avg_loss.update(loss.detach_(), bsz) | 112 | avg_loss.update(loss.detach_(), bsz) |
113 | avg_acc.update(acc.detach_(), bsz) | 113 | avg_acc.update(acc.detach_(), bsz) |
114 | 114 | ||