From f963d4cba5c4c6575d77be80621a40b615603ca3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 12 Jan 2023 13:50:22 +0100 Subject: Update --- training/lr.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'training') diff --git a/training/lr.py b/training/lr.py index 01f7f5e..84e30a0 100644 --- a/training/lr.py +++ b/training/lr.py @@ -26,7 +26,8 @@ class LRFinder(): val_dataloader, loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], on_train: Callable[[], _GeneratorContextManager] = nullcontext, - on_clip: Callable[[float], None] = noop, + on_before_optimize: Callable[[], None] = noop, + on_after_optimize: Callable[[float], None] = noop, on_eval: Callable[[], _GeneratorContextManager] = nullcontext ): self.accelerator = accelerator @@ -36,7 +37,8 @@ class LRFinder(): self.val_dataloader = val_dataloader self.loss_fn = loss_fn self.on_train = on_train - self.on_clip = on_clip + self.on_before_optimize = on_before_optimize + self.on_after_optimize = on_after_optimize self.on_eval = on_eval # self.model_state = copy.deepcopy(model.state_dict()) @@ -94,14 +96,15 @@ class LRFinder(): self.accelerator.backward(loss) - if self.accelerator.sync_gradients: - self.on_clip(lr_scheduler.get_last_lr()[0]) + self.on_before_optimize() self.optimizer.step() lr_scheduler.step() self.optimizer.zero_grad(set_to_none=True) if self.accelerator.sync_gradients: + self.on_after_optimize(lr_scheduler.get_last_lr()[0]) + progress_bar.update(1) self.model.eval() -- cgit v1.2.3-70-g09d2