Diffstat (limited to 'training')
 training/lr.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 01f7f5e..84e30a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,8 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[float], None] = noop,
+        on_before_optimize: Callable[[], None] = noop,
+        on_after_optimize: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -36,7 +37,8 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
-        self.on_clip = on_clip
+        self.on_before_optimize = on_before_optimize
+        self.on_after_optimize = on_after_optimize
         self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
@@ -94,14 +96,15 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                if self.accelerator.sync_gradients:
-                    self.on_clip(lr_scheduler.get_last_lr()[0])
+                self.on_before_optimize()
 
                 self.optimizer.step()
                 lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
 
                 if self.accelerator.sync_gradients:
+                    self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+
                     progress_bar.update(1)
 
             self.model.eval()
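
For context, a minimal caller-side sketch (not part of this patch) of how the renamed hooks might be wired up: clipping work that a caller previously did in on_clip fits into on_before_optimize, which the new code invokes just before optimizer.step(), while on_after_optimize keeps the sync_gradients guard and receives the last learning rate. The names model, optimizer, train_dataloader, and max_grad_norm are assumptions about the surrounding training script; only the keyword arguments visible in the hunks above are confirmed by this diff.

# Hypothetical usage sketch; `model`, `optimizer`, `max_grad_norm`, and the
# dataloaders are assumed from the surrounding training script.
def on_before_optimize():
    # Invoked on every step, right before optimizer.step(); the old on_clip
    # hook only fired when accelerator.sync_gradients was true.
    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)

def on_after_optimize(lr: float):
    # Still guarded by accelerator.sync_gradients; receives the last LR.
    print(f"optimizer stepped at lr={lr:.3e}")

lr_finder = LRFinder(
    accelerator=accelerator,
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    on_before_optimize=on_before_optimize,
    on_after_optimize=on_after_optimize,
)

One behavioral shift is visible in the last hunk: the old on_clip ran only when accelerator.sync_gradients was true, whereas on_before_optimize runs unconditionally, so hooks that must execute only on gradient-sync steps should check accelerator.sync_gradients themselves.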