From f87d9fdf541b0282249ddde1dc0302317350f998 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 29 Dec 2022 15:28:02 +0100
Subject: Update

---
 training/lr.py           | 13 +++++--------
 training/optimization.py | 14 +++++++++++---
 2 files changed, 16 insertions(+), 11 deletions(-)

(limited to 'training')

diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..ef01906 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,6 +43,9 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")

+        train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
+        val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")

@@ -51,10 +54,7 @@ class LRFinder():

             self.model.train()

-            for step, batch in enumerate(self.train_dataloader):
-                if step >= num_train_batches:
-                    break
-
+            for batch in train_workload:
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)

@@ -69,10 +69,7 @@ class LRFinder():
             self.model.eval()

             with torch.inference_mode():
-                for step, batch in enumerate(self.val_dataloader):
-                    if step >= num_val_batches:
-                        break
-
+                for batch in val_workload:
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
diff --git a/training/optimization.py b/training/optimization.py
index dfee2b5..3340544 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -14,6 +14,8 @@ def get_one_cycle_schedule(
     num_training_steps: int,
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 2,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1
@@ -27,7 +29,9 @@ def get_one_cycle_schedule(
         if warmup == "linear":
             return min_lr + progress * (1 - min_lr)

-        return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+        lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+        lr = lr ** warmup_exp
+        return min_lr + lr * (1 - min_lr)

     if annealing == "linear":
         thresh_down = thresh_up * 2
@@ -42,8 +46,12 @@ def get_one_cycle_schedule(
         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))

         if annealing == "half_cos":
-            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = lr ** annealing_exp
+            return lr

-        return 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = lr ** annealing_exp
+        return lr

     return LambdaLR(optimizer, lr_lambda, last_epoch)
--
cgit v1.2.3-70-g09d2
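
Usage note (not part of the commit): a minimal sketch of how the reworked one-cycle schedule might be driven with the new warmup_exp / annealing_exp parameters. The optimizer-first signature and the training.optimization import path are assumptions inferred from the hunks above (the function returns LambdaLR(optimizer, ...)), not something this patch shows directly; the model and step count are stand-ins.

    # Hypothetical usage sketch, assuming get_one_cycle_schedule(optimizer, num_training_steps, ...)
    import torch
    from training.optimization import get_one_cycle_schedule

    model = torch.nn.Linear(4, 4)                               # stand-in model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    num_training_steps = 1000
    scheduler = get_one_cycle_schedule(
        optimizer,
        num_training_steps=num_training_steps,
        warmup="cos",
        annealing="cos",
        warmup_exp=1,      # new in this commit: exponent applied to the warmup cosine (1 = unchanged shape)
        annealing_exp=2,   # new in this commit: squares the annealing cosine for a steeper tail
    )

    for _ in range(num_training_steps):
        optimizer.step()   # in real training this follows a forward/backward pass
        scheduler.step()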