From 67aaba2159bcda4c0b8538b1580a40f01e8f0964 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 2 Jan 2023 17:34:11 +0100
Subject: Update

---
 training/lr.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
 import math
 import copy
+from typing import Callable
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
+def noop():
+    pass
+
+
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
+    def __init__(
+        self,
+        accelerator,
+        model,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        loss_fn,
+        on_train: Callable[[], None] = noop,
+        on_eval: Callable[[], None] = noop
+    ):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
+        self.on_train = on_train
+        self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(
+        self,
+        min_lr,
+        skip_start: int = 10,
+        skip_end: int = 5,
+        num_epochs: int = 100,
+        num_train_batches: int = 1,
+        num_val_batches: int = math.inf,
+        smooth_f: float = 0.05,
+        diverge_th: int = 5
+    ):
         best_loss = None
         best_acc = None
 
@@ -50,6 +77,7 @@ class LRFinder():
         avg_acc = AverageMeter()
 
         self.model.train()
+        self.on_train()
 
         for step, batch in enumerate(self.train_dataloader):
             if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
             progress_bar.update(1)
 
         self.model.eval()
+        self.on_eval()
 
         with torch.inference_mode():
             for step, batch in enumerate(self.val_dataloader):
--
cgit v1.2.3-70-g09d2
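
For context, a minimal sketch of how the on_train/on_eval hooks added by this commit could be wired up from a training script. Only the LRFinder constructor signature and run() come from the commit; the accelerator, model, optimizer, dataloader, loss_fn, and text_encoder names are hypothetical stand-ins for objects the surrounding training script would already provide.

# Hypothetical usage sketch, not part of the commit. Assumes the usual
# training-script objects (accelerator, model, optimizer, dataloaders,
# loss_fn) already exist; `text_encoder` is an invented example module.
from training.lr import LRFinder

def on_train():
    # Example: put auxiliary modules back into training mode before the
    # training batches of each LR-finder epoch.
    text_encoder.train()

def on_eval():
    # Example: switch auxiliary modules to inference behaviour before the
    # validation batches.
    text_encoder.eval()

lr_finder = LRFinder(
    accelerator,
    model,
    optimizer,
    train_dataloader,
    val_dataloader,
    loss_fn,
    on_train=on_train,
    on_eval=on_eval,
)
lr_finder.run(min_lr=1e-6)

Because both hooks default to noop, existing callers that construct LRFinder without the new arguments keep their previous behaviour unchanged.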