Diffstat (limited to 'training')
-rw-r--r--   training/lr.py | 33
1 file changed, 31 insertions, 2 deletions
diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
 import math
 import copy
+from typing import Callable
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
+def noop():
+    pass
+
+
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
+    def __init__(
+        self,
+        accelerator,
+        model,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        loss_fn,
+        on_train: Callable[[], None] = noop,
+        on_eval: Callable[[], None] = noop
+    ):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
+        self.on_train = on_train
+        self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(
+        self,
+        min_lr,
+        skip_start: int = 10,
+        skip_end: int = 5,
+        num_epochs: int = 100,
+        num_train_batches: int = 1,
+        num_val_batches: int = math.inf,
+        smooth_f: float = 0.05,
+        diverge_th: int = 5
+    ):
         best_loss = None
         best_acc = None
 
@@ -50,6 +77,7 @@ class LRFinder():
             avg_acc = AverageMeter()
 
             self.model.train()
+            self.on_train()
 
             for step, batch in enumerate(self.train_dataloader):
                 if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
                 progress_bar.update(1)
 
             self.model.eval()
+            self.on_eval()
 
             with torch.inference_mode():
                 for step, batch in enumerate(self.val_dataloader):
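For reference, a minimal usage sketch of the new hooks, assuming the surrounding objects (accelerator, model, optimizer, dataloaders, loss_fn) are already set up as elsewhere in this repo. The helper callbacks and the min_lr value below are hypothetical and not part of this commit; the only contract the diff introduces is that on_train and on_eval are zero-argument callables invoked immediately after model.train() and model.eval() respectively.

    def enable_live_weights():
        # hypothetical helper: switch the model back to its raw training
        # weights before the next batch of training steps
        ...

    def enable_ema_weights():
        # hypothetical helper: copy averaged (e.g. EMA) weights into the
        # model so validation runs against them
        ...

    finder = LRFinder(
        accelerator,
        model,
        optimizer,
        train_dataloader,
        val_dataloader,
        loss_fn,
        on_train=enable_live_weights,
        on_eval=enable_ema_weights,
    )
    finder.run(min_lr=1e-6)  # min_lr value chosen for illustration only

Leaving both hooks at the noop default reproduces the previous behaviour, so existing callers are unaffected.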
