Diffstat (limited to 'training/lr.py')
-rw-r--r--  training/lr.py | 29 +++++++++++++++--------------
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
 
+from training.functional import TrainingCallbacks
 from training.util import AverageMeter
 
 
@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -90,7 +91,7 @@ class LRFinder():
 
             self.model.train()
 
-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@ class LRFinder():
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize(epoch)
+                    on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
                     if self.accelerator.sync_gradients:
-                        self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     progress_bar.update(1)
 
             self.model.eval()
 
             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
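
This commit replaces LRFinder's individually-passed hooks (on_train, on_before_optimize, on_after_optimize, on_eval) and its explicit model argument with a single TrainingCallbacks object imported from training/functional.py. That module is not part of this diff, so the sketch below is only a reconstruction of what the container could look like, inferred from how LRFinder uses it above: the field names come from the diff, but the types and no-op defaults are assumptions, not the actual definition.

    # Hypothetical sketch of TrainingCallbacks, inferred from its usage in
    # LRFinder above; the real definition lives in training/functional.py,
    # which this diff does not show.
    from contextlib import nullcontext
    from dataclasses import dataclass
    from typing import Any, Callable, ContextManager

    @dataclass
    class TrainingCallbacks:
        # Supplies the model (replaces the removed `model` constructor argument).
        on_model: Callable[[], Any] = lambda: None
        # One-time setup hook, run once before the first epoch.
        on_prepare: Callable[[], None] = lambda: None
        # Context manager wrapping the training phase of an epoch.
        on_train: Callable[[int], ContextManager] = lambda epoch: nullcontext()
        # Hooks around each optimizer step; the after-hook receives the
        # learning rate that was just applied.
        on_before_optimize: Callable[[int], None] = lambda epoch: None
        on_after_optimize: Callable[[float], None] = lambda lr: None
        # Context manager wrapping the evaluation phase.
        on_eval: Callable[[], ContextManager] = lambda: nullcontext()

Under that assumption, a call site that previously passed the model and each hook separately would now bundle them:

    # Hypothetical call site: the old per-argument hooks become one object.
    finder = LRFinder(
        accelerator=accelerator,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_fn=loss_fn,
        callbacks=TrainingCallbacks(on_model=lambda: model),
    )

Bundling the hooks keeps LRFinder's signature stable as callbacks are added (on_prepare, for instance, has no counterpart in the old signature). One trade-off of the `callbacks: TrainingCallbacks = TrainingCallbacks()` default is that it is evaluated once at definition time, so all callers that omit the argument share the same instance.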
