summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/lr.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/training/lr.py b/training/lr.py
index acc01a2..37588b6 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -58,7 +58,11 @@ class LRFinder():
58 losses = [] 58 losses = []
59 accs = [] 59 accs = []
60 60
61 lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs) 61 lr_scheduler = get_exponential_schedule(
62 self.optimizer,
63 end_lr,
64 num_epochs * min(num_train_batches, len(self.train_dataloader))
65 )
62 66
63 steps = min(num_train_batches, len(self.train_dataloader)) 67 steps = min(num_train_batches, len(self.train_dataloader))
64 steps += min(num_val_batches, len(self.val_dataloader)) 68 steps += min(num_val_batches, len(self.val_dataloader))
@@ -90,6 +94,7 @@ class LRFinder():
90 self.accelerator.backward(loss) 94 self.accelerator.backward(loss)
91 95
92 self.optimizer.step() 96 self.optimizer.step()
97 lr_scheduler.step()
93 self.optimizer.zero_grad(set_to_none=True) 98 self.optimizer.zero_grad(set_to_none=True)
94 99
95 if self.accelerator.sync_gradients: 100 if self.accelerator.sync_gradients:
@@ -109,8 +114,6 @@ class LRFinder():
109 114
110 progress_bar.update(1) 115 progress_bar.update(1)
111 116
112 lr_scheduler.step()
113
114 loss = avg_loss.avg.item() 117 loss = avg_loss.avg.item()
115 acc = avg_acc.avg.item() 118 acc = avg_acc.avg.item()
116 119