diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/lr.py | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/training/lr.py b/training/lr.py index 5343f24..8e558e1 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -1,3 +1,6 @@ | |||
| 1 | import math | ||
| 2 | import copy | ||
| 3 | |||
| 1 | import matplotlib.pyplot as plt | 4 | import matplotlib.pyplot as plt |
| 2 | import numpy as np | 5 | import numpy as np |
| 3 | import torch | 6 | import torch |
| @@ -16,15 +19,22 @@ class LRFinder(): | |||
| 16 | self.val_dataloader = val_dataloader | 19 | self.val_dataloader = val_dataloader |
| 17 | self.loss_fn = loss_fn | 20 | self.loss_fn = loss_fn |
| 18 | 21 | ||
| 19 | def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5): | 22 | self.model_state = copy.deepcopy(model.state_dict()) |
| 23 | self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | ||
| 24 | |||
| 25 | def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): | ||
| 20 | best_loss = None | 26 | best_loss = None |
| 21 | lrs = [] | 27 | lrs = [] |
| 22 | losses = [] | 28 | losses = [] |
| 23 | 29 | ||
| 24 | lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs) | 30 | lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs) |
| 31 | |||
| 32 | steps = min(num_train_batches, len(self.train_dataloader)) | ||
| 33 | steps += min(num_val_batches, len(self.val_dataloader)) | ||
| 34 | steps *= num_epochs | ||
| 25 | 35 | ||
| 26 | progress_bar = tqdm( | 36 | progress_bar = tqdm( |
| 27 | range(num_epochs * (num_train_steps + num_val_steps)), | 37 | range(steps), |
| 28 | disable=not self.accelerator.is_local_main_process, | 38 | disable=not self.accelerator.is_local_main_process, |
| 29 | dynamic_ncols=True | 39 | dynamic_ncols=True |
| 30 | ) | 40 | ) |
| @@ -38,6 +48,9 @@ class LRFinder(): | |||
| 38 | self.model.train() | 48 | self.model.train() |
| 39 | 49 | ||
| 40 | for step, batch in enumerate(self.train_dataloader): | 50 | for step, batch in enumerate(self.train_dataloader): |
| 51 | if step >= num_train_batches: | ||
| 52 | break | ||
| 53 | |||
| 41 | with self.accelerator.accumulate(self.model): | 54 | with self.accelerator.accumulate(self.model): |
| 42 | loss, acc, bsz = self.loss_fn(batch) | 55 | loss, acc, bsz = self.loss_fn(batch) |
| 43 | 56 | ||
| @@ -49,21 +62,17 @@ class LRFinder(): | |||
| 49 | if self.accelerator.sync_gradients: | 62 | if self.accelerator.sync_gradients: |
| 50 | progress_bar.update(1) | 63 | progress_bar.update(1) |
| 51 | 64 | ||
| 52 | if step >= num_train_steps: | ||
| 53 | break | ||
| 54 | |||
| 55 | self.model.eval() | 65 | self.model.eval() |
| 56 | 66 | ||
| 57 | with torch.inference_mode(): | 67 | with torch.inference_mode(): |
| 58 | for step, batch in enumerate(self.val_dataloader): | 68 | for step, batch in enumerate(self.val_dataloader): |
| 69 | if step >= num_val_batches: | ||
| 70 | break | ||
| 71 | |||
| 59 | loss, acc, bsz = self.loss_fn(batch) | 72 | loss, acc, bsz = self.loss_fn(batch) |
| 60 | avg_loss.update(loss.detach_(), bsz) | 73 | avg_loss.update(loss.detach_(), bsz) |
| 61 | 74 | ||
| 62 | if self.accelerator.sync_gradients: | 75 | progress_bar.update(1) |
| 63 | progress_bar.update(1) | ||
| 64 | |||
| 65 | if step >= num_val_steps: | ||
| 66 | break | ||
| 67 | 76 | ||
| 68 | lr_scheduler.step() | 77 | lr_scheduler.step() |
| 69 | 78 | ||
| @@ -87,6 +96,9 @@ class LRFinder(): | |||
| 87 | "lr": lr, | 96 | "lr": lr, |
| 88 | }) | 97 | }) |
| 89 | 98 | ||
| 99 | self.model.load_state_dict(self.model_state) | ||
| 100 | self.optimizer.load_state_dict(self.optimizer_state) | ||
| 101 | |||
| 90 | if loss > diverge_th * best_loss: | 102 | if loss > diverge_th * best_loss: |
| 91 | print("Stopping early, the loss has diverged") | 103 | print("Stopping early, the loss has diverged") |
| 92 | break | 104 | break |
| @@ -120,8 +132,8 @@ class LRFinder(): | |||
| 120 | ax.set_ylabel("Loss") | 132 | ax.set_ylabel("Loss") |
| 121 | 133 | ||
| 122 | 134 | ||
| 123 | def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1): | 135 | def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1): |
| 124 | def lr_lambda(current_epoch: int): | 136 | def lr_lambda(current_epoch: int): |
| 125 | return (current_epoch / num_epochs) ** 5 | 137 | return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr) |
| 126 | 138 | ||
| 127 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 139 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
