Diffstat (limited to 'training')
 training/lr.py | 9 +++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index acc01a2..37588b6 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -58,7 +58,11 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(
+            self.optimizer,
+            end_lr,
+            num_epochs * min(num_train_batches, len(self.train_dataloader))
+        )
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -90,6 +94,7 @@ class LRFinder():
         self.accelerator.backward(loss)
 
         self.optimizer.step()
+        lr_scheduler.step()
         self.optimizer.zero_grad(set_to_none=True)
 
         if self.accelerator.sync_gradients:
@@ -109,8 +114,6 @@ class LRFinder():
 
         progress_bar.update(1)
 
-        lr_scheduler.step()
-
         loss = avg_loss.avg.item()
         acc = avg_acc.avg.item()
 
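Note on the change: the scheduler is now stepped once per optimizer step instead of once per epoch, so it must be constructed with the total number of training batches (num_epochs * min(num_train_batches, len(self.train_dataloader))) rather than num_epochs alone; otherwise the learning-rate sweep would end after only num_epochs scheduler steps. The body of get_exponential_schedule is not shown in this diff; below is a minimal sketch of what it plausibly does, assuming it sweeps the learning rate geometrically from the optimizer's configured LR to end_lr over num_steps scheduler steps.

from torch.optim.lr_scheduler import LambdaLR

def get_exponential_schedule(optimizer, end_lr, num_steps):
    # Hypothetical sketch, not the repo's actual implementation:
    # grow the LR geometrically from the optimizer's base LR to end_lr
    # over num_steps calls to lr_scheduler.step().
    base_lr = optimizer.defaults["lr"]

    def lr_lambda(step):
        # LambdaLR multiplies the base LR by this factor; the factor
        # reaches end_lr / base_lr (i.e. an LR of end_lr) at step == num_steps.
        return (end_lr / base_lr) ** (step / num_steps)

    return LambdaLR(optimizer, lr_lambda)

Under this assumption, calling lr_scheduler.step() after every optimizer.step() traverses the full LR range exactly once over the run, which is the behaviour an LR finder needs.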
