From 3575d041f1507811b577fd2c653171fb51c0a386 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 20 Jan 2023 14:26:17 +0100
Subject: Restored LR finder

---
 training/optimization.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/training/optimization.py b/training/optimization.py
index 6dee4bc..6c9a35d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -87,6 +87,15 @@ def get_one_cycle_schedule(
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_step: int):
+        return (end_lr / base_lr) ** (current_step / num_training_steps)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
+
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+
+
 def get_scheduler(
     id: str,
     optimizer: torch.optim.Optimizer,
@@ -97,6 +106,7 @@ def get_scheduler(
     annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
+    end_lr: float = 1e3,
     cycles: int = 1,
     train_epochs: int = 100,
     warmup_epochs: int = 10,
@@ -117,6 +127,15 @@ def get_scheduler(
             annealing_exp=annealing_exp,
             min_lr=min_lr,
         )
+    elif id == "exponential_growth":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_exponential_growing_schedule(
+            optimizer=optimizer,
+            end_lr=end_lr,
+            num_training_steps=num_training_steps,
+        )
     elif id == "cosine_with_restarts":
         if cycles is None:
             cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
--
cgit v1.2.3-54-g00ecf
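
A minimal, self-contained sketch (not part of the patch) of how the new
"exponential_growth" schedule behaves. LambdaLR multiplies each param group's
base LR by the lambda's return value, so the factor
(end_lr / base_lr) ** (current_step / num_training_steps) sweeps the LR
geometrically from base_lr at step 0 up to end_lr at the last step, which is
the ramp an LR finder needs. The toy model, SGD optimizer, base LR of 1e-6,
end_lr of 1e3 and 100 steps below are illustrative assumptions, not values
taken from this repository.

    from functools import partial

    import torch
    from torch.optim.lr_scheduler import LambdaLR


    def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
        # Same logic as the function added in this patch: one lambda per
        # param group, each bound to that group's base LR.
        def lr_lambda(base_lr: float, current_step: int):
            return (end_lr / base_lr) ** (current_step / num_training_steps)

        lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

        return LambdaLR(optimizer, lr_lambdas, last_epoch)


    model = torch.nn.Linear(4, 4)  # placeholder model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)  # assumed starting LR
    scheduler = get_exponential_growing_schedule(optimizer, end_lr=1e3, num_training_steps=100)

    for step in range(100):
        optimizer.step()    # a real LR finder would run loss.backward() first
        scheduler.step()
        if step % 20 == 0:
            print(step, scheduler.get_last_lr())  # LR climbs from 1e-6 toward 1e3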