From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 18:59:26 +0100
Subject: More modularization

---
 training/lr.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index 84e30a0..7584ba2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -16,6 +16,10 @@ def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 class LRFinder():
     def __init__(
         self,
@@ -25,10 +29,10 @@ class LRFinder():
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_before_optimize: Callable[[], None] = noop,
+        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+        on_before_optimize: Callable[[int], None] = noop,
         on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
     ):
         self.accelerator = accelerator
         self.model = model
@@ -86,7 +90,7 @@ class LRFinder():
 
         self.model.train()
 
-        with self.on_train():
+        with self.on_train(epoch):
             for step, batch in enumerate(self.train_dataloader):
                 if step >= num_train_batches:
                     break
@@ -96,7 +100,7 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                self.on_before_optimize()
+                self.on_before_optimize(epoch)
 
                 self.optimizer.step()
                 lr_scheduler.step()
-- 
cgit v1.2.3-54-g00ecf
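
A minimal usage sketch (not part of the patch) of the hook pattern this change introduces: LRFinder now calls on_train(epoch) and on_before_optimize(epoch), so the default hooks must accept arguments. Presumably that is why the default changes from nullcontext to noop_ctx: nullcontext would silently treat a single positional argument as its enter_result and cannot take more than one, whereas noop_ctx discards whatever it is given and returns an empty context manager. The names train_ctx and run_epoch below are hypothetical stand-ins for illustration, not identifiers from the repository.

    from contextlib import contextmanager, nullcontext

    def noop(*args, **kwargs):
        # Accept and ignore any arguments.
        pass

    def noop_ctx(*args, **kwargs):
        # Accept and ignore any arguments, then return an empty context manager.
        return nullcontext()

    @contextmanager
    def train_ctx(epoch: int):
        # Hypothetical caller-supplied hook with per-epoch setup/teardown.
        print(f"entering training for epoch {epoch}")
        try:
            yield
        finally:
            print(f"leaving training for epoch {epoch}")

    def run_epoch(epoch: int, on_train=noop_ctx, on_before_optimize=noop):
        # Mirrors the call sites touched by this patch: both hooks receive
        # the current epoch, and the defaults accept it without error.
        with on_train(epoch):
            on_before_optimize(epoch)

    run_epoch(0)                                                 # defaults: no-ops
    run_epoch(1, on_train=train_ctx, on_before_optimize=print)   # custom hooks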