From a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 22:06:05 +0100 Subject: Update --- training/lr.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/lr.py b/training/lr.py index c8dc040..3cdf994 100644 --- a/training/lr.py +++ b/training/lr.py @@ -12,7 +12,7 @@ from tqdm.auto import tqdm from training.util import AverageMeter -def noop(): +def noop(*args, **kwards): pass @@ -26,6 +26,7 @@ class LRFinder(): val_dataloader, loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], on_train: Callable[[], None] = noop, + on_clip: Callable[[], None] = noop, on_eval: Callable[[], None] = noop ): self.accelerator = accelerator @@ -35,6 +36,7 @@ class LRFinder(): self.val_dataloader = val_dataloader self.loss_fn = loss_fn self.on_train = on_train + self.on_clip = on_clip self.on_eval = on_eval # self.model_state = copy.deepcopy(model.state_dict()) @@ -93,6 +95,9 @@ class LRFinder(): self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.on_clip() + self.optimizer.step() lr_scheduler.step() self.optimizer.zero_grad(set_to_none=True) -- cgit v1.2.3-70-g09d2