-rw-r--r--  training/lr.py            13
-rw-r--r--  training/optimization.py  14
2 files changed, 16 insertions, 11 deletions
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..ef01906 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,6 +43,9 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
+        val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -51,10 +54,7 @@ class LRFinder():
 
             self.model.train()
 
-            for step, batch in enumerate(self.train_dataloader):
-                if step >= num_train_batches:
-                    break
-
+            for batch in train_workload:
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
 
@@ -69,10 +69,7 @@ class LRFinder():
             self.model.eval()
 
             with torch.inference_mode():
-                for step, batch in enumerate(self.val_dataloader):
-                    if step >= num_val_batches:
-                        break
-
+                for batch in val_workload:
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
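The lr.py change swaps the per-epoch enumerate/break slicing of the dataloaders for two lists materialized once up front, so every epoch of the finder sweep replays exactly the same batches. A self-contained sketch of the pattern, with toy data standing in for LRFinder's dataloaders (everything except the comprehension and num_train_batches is illustrative):

```python
import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for self.train_dataloader in LRFinder.
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
num_train_batches = 4

# Pattern from the diff: materialize the first N batches once, up front.
# Note the comprehension still walks the whole dataloader and filters;
# itertools.islice would stop iterating after N batches instead.
train_workload = [batch for i, batch in enumerate(train_dataloader) if i < num_train_batches]
train_workload_alt = list(itertools.islice(train_dataloader, num_train_batches))

for epoch in range(3):
    for batch in train_workload:  # identical batches every epoch, no re-iteration
        inputs, labels = batch
        assert inputs.shape == (32, 8)
```

Caching the workload also freezes any shuffling or augmentation at capture time and keeps the batches resident in memory, which is acceptable for a short LR-finder sweep but would not be for full training.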
diff --git a/training/optimization.py b/training/optimization.py
index dfee2b5..3340544 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -14,6 +14,8 @@ def get_one_cycle_schedule(
     num_training_steps: int,
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 2,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1
@@ -27,7 +29,9 @@ def get_one_cycle_schedule(
             if warmup == "linear":
                 return min_lr + progress * (1 - min_lr)
 
-            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+            lr = lr ** warmup_exp
+            return min_lr + lr * (1 - min_lr)
 
         if annealing == "linear":
             thresh_down = thresh_up * 2
@@ -42,8 +46,12 @@ def get_one_cycle_schedule(
         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
 
         if annealing == "half_cos":
-            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = lr ** annealing_exp
+            return lr
 
-        return 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = lr ** annealing_exp
+        return lr
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
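Stitching the optimization.py hunks together, the schedule after this commit looks roughly like the sketch below. The diff never shows how thresh_up is computed, nor what the "linear" annealing branch does past thresh_down = thresh_up * 2, so those parts are assumptions here, and the upstream `int` annotations on min_lr/mid_point are written as `float` since the defaults are floats:

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def get_one_cycle_schedule(
    optimizer,
    num_training_steps: int,
    warmup: str = "cos",
    annealing: str = "cos",
    warmup_exp: int = 1,
    annealing_exp: int = 2,
    min_lr: float = 0.04,   # annotated `int` upstream; the defaults are floats
    mid_point: float = 0.3,
    last_epoch: int = -1,
):
    # ASSUMPTION: this line is not visible in the diff; treating mid_point
    # as the fraction of total steps spent in warmup.
    thresh_up = int(num_training_steps * mid_point)

    def lr_lambda(current_step: int) -> float:
        if current_step < thresh_up:
            progress = current_step / max(1, thresh_up)
            if warmup == "linear":
                return min_lr + progress * (1 - min_lr)
            # Cosine warmup: 0.5 * (1 + cos(pi * (1 + p))) rises 0 -> 1.
            # warmup_exp (new in this commit) sharpens the curve; the commit
            # also rescales by (1 - min_lr) so the peak lands exactly at 1.0.
            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
            lr = lr ** warmup_exp
            return min_lr + lr * (1 - min_lr)

        # The "linear" annealing branch is omitted: its body is not in the diff.
        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))

        if annealing == "half_cos":
            # Quarter-wave cosine: 1 + cos(pi * (0.5 + 0.5 * p)) falls 1 -> 0.
            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
            lr = lr ** annealing_exp
            return lr

        # Full cosine annealing, likewise raised to annealing_exp.
        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
        lr = lr ** annealing_exp
        return lr

    return LambdaLR(optimizer, lr_lambda, last_epoch)


# Quick check of the multiplier curve; base lr of 1.0 makes the printed
# values the raw lambda outputs.
model = torch.nn.Linear(8, 2)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
sched = get_one_cycle_schedule(opt, num_training_steps=100)
for step in range(100):
    opt.step()
    sched.step()
    if step in (0, 29, 99):
        print(step, sched.get_last_lr()[0])
```

With the defaults, warmup_exp = 1 leaves the warmup shape untouched, while annealing_exp = 2 squares the decay term, so the multiplier falls toward zero faster through the back half of the run.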