summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/lr.py14
-rw-r--r--training/optimization.py10
2 files changed, 15 insertions, 9 deletions
diff --git a/training/lr.py b/training/lr.py
index c1fa3a0..c0e9b3f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -19,8 +19,8 @@ class LRFinder():
19 self.val_dataloader = val_dataloader 19 self.val_dataloader = val_dataloader
20 self.loss_fn = loss_fn 20 self.loss_fn = loss_fn
21 21
22 self.model_state = copy.deepcopy(model.state_dict()) 22 # self.model_state = copy.deepcopy(model.state_dict())
23 self.optimizer_state = copy.deepcopy(optimizer.state_dict()) 23 # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
24 24
25 def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): 25 def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
26 best_loss = None 26 best_loss = None
@@ -109,8 +109,8 @@ class LRFinder():
109 "lr": lr, 109 "lr": lr,
110 }) 110 })
111 111
112 self.model.load_state_dict(self.model_state) 112 # self.model.load_state_dict(self.model_state)
113 self.optimizer.load_state_dict(self.optimizer_state) 113 # self.optimizer.load_state_dict(self.optimizer_state)
114 114
115 if loss > diverge_th * best_loss: 115 if loss > diverge_th * best_loss:
116 print("Stopping early, the loss has diverged") 116 print("Stopping early, the loss has diverged")
@@ -127,12 +127,14 @@ class LRFinder():
127 127
128 fig, ax_loss = plt.subplots() 128 fig, ax_loss = plt.subplots()
129 129
130 ax_loss.plot(lrs, losses, color='red', label='Loss') 130 ax_loss.plot(lrs, losses, color='red')
131 ax_loss.set_xscale("log") 131 ax_loss.set_xscale("log")
132 ax_loss.set_xlabel("Learning rate") 132 ax_loss.set_xlabel("Learning rate")
133 ax_loss.set_ylabel("Loss")
133 134
134 # ax_acc = ax_loss.twinx() 135 # ax_acc = ax_loss.twinx()
135 # ax_acc.plot(lrs, accs, color='blue', label='Accuracy') 136 # ax_acc.plot(lrs, accs, color='blue')
137 # ax_acc.set_ylabel("Accuracy")
136 138
137 print("LR suggestion: steepest gradient") 139 print("LR suggestion: steepest gradient")
138 min_grad_idx = None 140 min_grad_idx = None
diff --git a/training/optimization.py b/training/optimization.py
index 3809f3b..a0c8673 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
6logger = logging.get_logger(__name__) 6logger = logging.get_logger(__name__)
7 7
8 8
9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.01, mid_point=0.4, last_epoch=-1): 9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1):
10 """ 10 """
11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
@@ -35,8 +35,12 @@ def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_l
35 35
36 progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) 36 progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
37 return max(0.0, progress) * min_lr 37 return max(0.0, progress) * min_lr
38 else: 38
39 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) 39 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
40
41 if annealing == "half_cos":
40 return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))) 42 return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))
41 43
44 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
45
42 return LambdaLR(optimizer, lr_lambda, last_epoch) 46 return LambdaLR(optimizer, lr_lambda, last_epoch)