Diffstat (limited to 'training')
 training/lr.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index dfb1743..01f7f5e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -26,7 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_clip: Callable[[], None] = noop,
+        on_clip: Callable[[float], None] = noop,
         on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
@@ -95,7 +95,7 @@ class LRFinder():
             self.accelerator.backward(loss)

             if self.accelerator.sync_gradients:
-                self.on_clip()
+                self.on_clip(lr_scheduler.get_last_lr()[0])

             self.optimizer.step()
             lr_scheduler.step()
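
The patch changes the on_clip callback to receive the scheduler's current learning rate, so clipping logic can observe (or adapt to) the LR at each synchronized step. Below is a minimal usage sketch, not part of the repository: it assumes an `accelerator` and `model` already prepared with Hugging Face accelerate, and the fixed `max_norm=1.0` plus the logging are illustrative choices.

def on_clip(current_lr: float) -> None:
    # Clip gradients to a fixed norm and record which LR the step used;
    # before this change the callback had no way to see the current LR.
    # Accelerator.clip_grad_norm_ returns the pre-clip total gradient norm.
    grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"lr={current_lr:.3e} grad_norm={grad_norm:.3f}")

A callback with this signature would then be passed as the on_clip argument when constructing LRFinder; the default noop still type-checks since it ignores its arguments.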