From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 12:33:52 +0100
Subject: Update

---
 training/lr.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

(limited to 'training/lr.py')

diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
 
+from training.functional import TrainingCallbacks
 from training.util import AverageMeter
 
 
@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -90,7 +91,7 @@
 
             self.model.train()
 
-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@
 
                         self.accelerator.backward(loss)
 
-                        self.on_before_optimize(epoch)
+                        on_before_optimize(epoch)
 
                         self.optimizer.step()
                         lr_scheduler.step()
                         self.optimizer.zero_grad(set_to_none=True)
 
                         if self.accelerator.sync_gradients:
-                            self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                            on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     progress_bar.update(1)
 
             self.model.eval()
 
             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
--
cgit v1.2.3-54-g00ecf
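
Note: training.functional.TrainingCallbacks is imported but not defined in this patch. Judging by the attributes LRFinder now accesses (on_model, on_prepare, on_train, on_before_optimize, on_after_optimize, on_eval) and by the noop/noop_ctx defaults the old signature used, it is presumably a dataclass of no-op-defaulted callbacks. A minimal sketch, assuming the field names and defaults; only the attributes actually used above are grounded in the diff:

    # Hypothetical reconstruction of training.functional.TrainingCallbacks.
    # Field names are inferred from the attributes accessed in this patch;
    # the no-op defaults are assumptions, mirroring the old noop/noop_ctx defaults.
    from contextlib import contextmanager
    from dataclasses import dataclass
    from typing import Any, Callable, Iterator


    def noop(*args, **kwargs) -> None:
        pass


    @contextmanager
    def noop_ctx(*args, **kwargs) -> Iterator[None]:
        yield


    @dataclass
    class TrainingCallbacks:
        on_model: Callable[[], Any] = noop                 # expected to return the model to train
        on_prepare: Callable[[], None] = noop              # one-time setup before the loop starts
        on_train: Callable[[int], Any] = noop_ctx          # context manager entered per training epoch
        on_before_optimize: Callable[[int], None] = noop
        on_after_optimize: Callable[[float], None] = noop  # receives the last learning rate
        on_eval: Callable[[], Any] = noop_ctx              # context manager wrapped around validation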
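
At the call site, the model and per-phase hooks now travel in one object, and LRFinder pulls the model out of it at construction time (self.model = callbacks.on_model()), so the model must exist before the finder is built. A hedged usage sketch; model, accelerator, the dataloaders, and loss_fn are placeholders, not names taken from this patch:

    # Hypothetical call site after this change.
    callbacks = TrainingCallbacks(
        on_model=lambda: model,  # the finder reads the model via callbacks.on_model()
    )

    finder = LRFinder(
        accelerator=accelerator,  # note: the old `model=` parameter is gone
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_fn=loss_fn,
        callbacks=callbacks,
    )

One design note: the default callbacks: TrainingCallbacks = TrainingCallbacks() creates a single instance at function-definition time that is shared by every call omitting the argument. That is harmless while TrainingCallbacks holds only stateless callables, but worth keeping in mind if mutable state is ever added to it.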