From bed44095ab99440467c2f302899b970c92baebf8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 4 Jan 2023 10:32:58 +0100
Subject: Better eval generator

---
 training/lr.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index a3144ba..c8dc040 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
+        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -89,7 +89,7 @@ class LRFinder():
                     break

                 with self.accelerator.accumulate(self.model):
-                    loss, acc, bsz = self.loss_fn(batch)
+                    loss, acc, bsz = self.loss_fn(step, batch)

                     self.accelerator.backward(loss)

@@ -108,7 +108,7 @@ class LRFinder():
                     if step >= num_val_batches:
                         break

-                    loss, acc, bsz = self.loss_fn(batch, True)
+                    loss, acc, bsz = self.loss_fn(step, batch, True)

                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
-- 
cgit v1.2.3-70-g09d2
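
Note: the patch threads the current step index into loss_fn so the callback can vary its behaviour per batch (per the subject line, e.g. seeding an evaluation generator). Below is a minimal sketch of a callback compatible with the new (step, batch[, eval]) signature; make_loss_fn, the (inputs, targets) batch layout, the seed offset, and the noise term are illustrative assumptions, not code from this repository.

    import torch
    import torch.nn.functional as F

    def make_loss_fn(model, seed: int = 0):
        """Build a loss_fn with the (step, batch[, eval]) signature LRFinder now expects."""
        def loss_fn(step: int, batch, is_eval: bool = False):
            inputs, targets = batch

            if is_eval:
                # Assumption: the step index seeds a per-batch generator so the
                # randomness used during evaluation is reproducible across runs.
                generator = torch.Generator().manual_seed(seed + step)
            else:
                generator = None

            noise = torch.randn(inputs.shape, generator=generator)
            preds = model(inputs + 0.01 * noise)

            loss = F.mse_loss(preds, targets)
            acc = (loss.detach() < 1.0).float()  # placeholder "accuracy" metric
            bsz = inputs.shape[0]
            return loss, acc, bsz

        return loss_fn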