 training/lr.py           | 13 +++++--------
 training/optimization.py | 14 +++++++++++---
 2 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..ef01906 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,6 +43,9 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
+        val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -51,10 +54,7 @@ class LRFinder():
 
             self.model.train()
 
-            for step, batch in enumerate(self.train_dataloader):
-                if step >= num_train_batches:
-                    break
-
+            for batch in train_workload:
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
 
@@ -69,10 +69,7 @@ class LRFinder():
             self.model.eval()
 
             with torch.inference_mode():
-                for step, batch in enumerate(self.val_dataloader):
-                    if step >= num_val_batches:
-                        break
-
+                for batch in val_workload:
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
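The lr.py change swaps the per-epoch `enumerate`/`break` pattern for two batch lists built once up front, so every LR-finder epoch replays the identical workload instead of re-iterating (and, with a shuffling sampler, re-drawing) the dataloaders. Below is a minimal sketch of the same pattern using `itertools.islice` in place of the comprehension; all names are stand-ins, not the repo's objects:

```python
# Sketch only: `train_dataloader` and `num_train_batches` are hypothetical
# stand-ins. islice(...) truncates the iterable the same way the removed
# `if step >= num_train_batches: break` did, but the list is built once
# and reused, so each epoch sees the exact same batches.
from itertools import islice

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
num_train_batches = 3

train_workload = list(islice(train_dataloader, num_train_batches))

for epoch in range(2):
    for x, y in train_workload:  # identical batches every epoch
        ...  # forward/backward pass would go here
```

The trade-off is that the sliced batches stay resident in host memory for the whole run, which is fine for an LR finder's short workload but would not suit a full training loop.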
diff --git a/training/optimization.py b/training/optimization.py
index dfee2b5..3340544 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -14,6 +14,8 @@ def get_one_cycle_schedule(
     num_training_steps: int,
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 2,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1
@@ -27,7 +29,9 @@ def get_one_cycle_schedule(
             if warmup == "linear":
                 return min_lr + progress * (1 - min_lr)
 
-            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+            lr = lr ** warmup_exp
+            return min_lr + lr * (1 - min_lr)
 
         if annealing == "linear":
             thresh_down = thresh_up * 2
@@ -42,8 +46,12 @@ def get_one_cycle_schedule(
         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
 
         if annealing == "half_cos":
-            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = lr ** annealing_exp
+            return lr
 
-        return 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = lr ** annealing_exp
+        return lr
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
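Besides exposing `warmup_exp` and `annealing_exp`, the optimization.py hunk also fixes the cosine warmup's range: the old `min_lr + 0.5 * (1.0 + cos(...))` peaked at `min_lr + 1`, while the new `min_lr + lr * (1 - min_lr)` lands exactly on 1.0. (The unchanged `min_lr: int = 0.04` and `mid_point: int = 0.3` annotations still look like they should read `float`.) The following standalone sketch re-implements only the patched `"cos"` warmup / `"cos"` annealing path so the shape can be sanity-checked; the linear-annealing branch and the exact `thresh_up` and warmup-progress formulas are not visible in these hunks, so the ones below are assumptions:

```python
import math

# Sketch of the post-patch multiplier, "cos" warmup + "cos" annealing only.
# ASSUMPTIONS (not shown in the diff): thresh_up = num_training_steps *
# mid_point, and warmup progress = step / thresh_up.
def one_cycle_multiplier(step, num_training_steps, min_lr=0.04,
                         mid_point=0.3, warmup_exp=1, annealing_exp=2):
    thresh_up = int(num_training_steps * mid_point)
    if step < thresh_up:
        progress = step / max(1, thresh_up)
        lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))  # rises 0 -> 1
        return min_lr + (lr ** warmup_exp) * (1 - min_lr)
    progress = (step - thresh_up) / max(1, num_training_steps - thresh_up)
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))  # falls 1 -> 0
    return lr ** annealing_exp  # exponent > 1 pushes the tail toward 0 sooner

# The default annealing_exp=2 squares the cosine, sharpening the decay:
for step in (0, 15, 30, 65, 100):
    print(step, round(one_cycle_multiplier(step, 100), 4))
```

With these assumptions the multiplier starts at `min_lr` (0.04), reaches 1.0 at 30% of training, and decays to 0; raising `annealing_exp` keeps the warmup intact but concentrates more of the run at small learning rates.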