From a72b6260c117cabe4fcb2996cce4f870986df99b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 3 Jan 2023 12:40:16 +0100 Subject: Added vector dropout --- training/lr.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/lr.py b/training/lr.py index acc01a2..37588b6 100644 --- a/training/lr.py +++ b/training/lr.py @@ -58,7 +58,11 @@ class LRFinder(): losses = [] accs = [] - lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs) + lr_scheduler = get_exponential_schedule( + self.optimizer, + end_lr, + num_epochs * min(num_train_batches, len(self.train_dataloader)) + ) steps = min(num_train_batches, len(self.train_dataloader)) steps += min(num_val_batches, len(self.val_dataloader)) @@ -90,6 +94,7 @@ class LRFinder(): self.accelerator.backward(loss) self.optimizer.step() + lr_scheduler.step() self.optimizer.zero_grad(set_to_none=True) if self.accelerator.sync_gradients: @@ -109,8 +114,6 @@ class LRFinder(): progress_bar.update(1) - lr_scheduler.step() - loss = avg_loss.avg.item() acc = avg_acc.avg.item() -- cgit v1.2.3-70-g09d2