 training/lr.py           | 14 ++++++++------
 training/optimization.py | 10 +++++++---
 2 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/training/lr.py b/training/lr.py
index c1fa3a0..c0e9b3f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -19,8 +19,8 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
 
-        self.model_state = copy.deepcopy(model.state_dict())
-        self.optimizer_state = copy.deepcopy(optimizer.state_dict())
+        # self.model_state = copy.deepcopy(model.state_dict())
+        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
     def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
@@ -109,8 +109,8 @@ class LRFinder():
                 "lr": lr,
             })
 
-            self.model.load_state_dict(self.model_state)
-            self.optimizer.load_state_dict(self.optimizer_state)
+            # self.model.load_state_dict(self.model_state)
+            # self.optimizer.load_state_dict(self.optimizer_state)
 
             if loss > diverge_th * best_loss:
                 print("Stopping early, the loss has diverged")
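Note: with both the deepcopy snapshot in __init__ and the load_state_dict restore in run() commented out, the finder now leaves the model and optimizer in whatever state the sweep ends in. A caller that still needs the original weights can snapshot and restore them externally; a minimal sketch, assuming a standard PyTorch model and optimizer (the names here are illustrative, not part of this repo):

import copy

from torch import nn, optim

model = nn.Linear(10, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Snapshot before the LR sweep; deepcopy so later training steps
# cannot mutate the saved tensors in place.
model_state = copy.deepcopy(model.state_dict())
optimizer_state = copy.deepcopy(optimizer.state_dict())

# ... run the LR range test here; it mutates model and optimizer ...

# Restore afterwards so real training starts from the original weights.
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)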
@@ -127,12 +127,14 @@
 
         fig, ax_loss = plt.subplots()
 
-        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
         ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_ylabel("Loss")
 
         # ax_acc = ax_loss.twinx()
-        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
+        # ax_acc.plot(lrs, accs, color='blue')
+        # ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
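The plot now labels each y-axis directly instead of passing label= to plot(), which suits the (still commented-out) twin-axis accuracy curve: with two independent y-scales, per-axis labels are clearer than a shared legend. A minimal sketch of the intended layout, with made-up data:

import matplotlib.pyplot as plt

# Illustrative values only; the real data comes from the LR sweep.
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [2.3, 2.1, 1.4, 1.9, 5.0]
accs = [0.10, 0.15, 0.55, 0.35, 0.05]

fig, ax_loss = plt.subplots()
ax_loss.plot(lrs, losses, color='red')
ax_loss.set_xscale("log")
ax_loss.set_xlabel("Learning rate")
ax_loss.set_ylabel("Loss")

# Second y-axis sharing the same x-axis, as in the commented-out block.
ax_acc = ax_loss.twinx()
ax_acc.plot(lrs, accs, color='blue')
ax_acc.set_ylabel("Accuracy")

plt.show()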
diff --git a/training/optimization.py b/training/optimization.py
index 3809f3b..a0c8673 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.01, mid_point=0.4, last_epoch=-1):
+def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
     a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
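The new defaults change the schedule's shape rather than its interface: judging by the parameter names, mid_point=0.3 places the peak LR earlier in training and min_lr=0.04 raises the floor of the final decay (fractions of total steps and of the peak LR, respectively). A usage sketch under that reading, assuming the repo is importable as a package:

import torch

from training.optimization import get_one_cycle_schedule

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Defaults now peak ~30% of the way in and decay toward 4% of the peak LR.
scheduler = get_one_cycle_schedule(optimizer, num_training_steps=1000)

for step in range(1000):
    # ... forward, backward, optimizer.step() ...
    scheduler.step()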
@@ -35,8 +35,12 @@ def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_l
 
             progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
             return max(0.0, progress) * min_lr
-        else:
-            progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
+
+        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
+
+        if annealing == "half_cos":
             return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))
 
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
     return LambdaLR(optimizer, lr_lambda, last_epoch)
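The restructuring also changes the default annealing curve: annealing="cos" now uses plain cosine annealing, 0.5 * (1 + cos(pi * progress)), while annealing="half_cos" keeps the previous curve, 1 + cos(pi * (0.5 + 0.5 * progress)). Both multipliers fall from 1 to 0 over the phase, but half_cos drops steeply at the start, whereas the new default is flat at both ends. A small standalone check of the two formulas:

import math

def cos_anneal(progress):
    # New default: symmetric cosine over [0, pi]; flat at both ends.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

def half_cos_anneal(progress):
    # "half_cos": cosine over [pi/2, pi]; steep at the start, flat at the end.
    return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"progress={p:.2f}  cos={cos_anneal(p):.3f}  half_cos={half_cos_anneal(p):.3f}")

At progress 0.5, for example, the default multiplier is still 0.5 while half_cos has already fallen to about 0.29.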