Diffstat (limited to 'training')
-rw-r--r--  training/lr.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index acc01a2..37588b6 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -58,7 +58,11 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(
+            self.optimizer,
+            end_lr,
+            num_epochs * min(num_train_batches, len(self.train_dataloader))
+        )
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -90,6 +94,7 @@ class LRFinder():
             self.accelerator.backward(loss)
 
             self.optimizer.step()
+            lr_scheduler.step()
             self.optimizer.zero_grad(set_to_none=True)
 
             if self.accelerator.sync_gradients:
@@ -109,8 +114,6 @@ class LRFinder():
 
             progress_bar.update(1)
 
-        lr_scheduler.step()
-
         loss = avg_loss.avg.item()
         acc = avg_acc.avg.item()
 
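In short, the commit moves the scheduler from one step per epoch to one step per optimizer update, so the schedule's horizon must be the total number of optimizer steps (num_epochs * batches per epoch) rather than num_epochs. A minimal runnable sketch of that pattern follows; the real get_exponential_schedule is defined elsewhere in this repo, and the stand-in below only assumes it grows the learning rate geometrically from the optimizer's base LR to end_lr over num_steps scheduler steps.

import torch
from torch.optim.lr_scheduler import LambdaLR

# Sketch only: stand-in for the repo's get_exponential_schedule. Assumes it
# sweeps the LR geometrically from the optimizer's base LR to end_lr over
# num_steps calls to .step().
def get_exponential_schedule(optimizer, end_lr, num_steps):
    base_lr = optimizer.param_groups[0]["lr"]
    # lr(t) = base_lr * (end_lr / base_lr) ** (t / num_steps),
    # so lr(0) == base_lr and lr(num_steps) == end_lr.
    return LambdaLR(optimizer, lambda step: (end_lr / base_lr) ** (step / num_steps))

# Toy model and hyperparameters, just to make the pattern executable.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
num_epochs, steps_per_epoch, end_lr = 3, 10, 1.0

# Size the schedule in optimizer steps, as the diff does ...
lr_scheduler = get_exponential_schedule(
    optimizer, end_lr, num_epochs * steps_per_epoch
)

for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # ... and step it once per batch, after optimizer.step()
        optimizer.zero_grad(set_to_none=True)

Stepping once per batch keeps the exponential LR sweep smooth across the whole run instead of jumping once per epoch, which is what an LR-range test like LRFinder needs.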