author     Volpeon <git@volpeon.ink>  2022-12-29 09:00:19 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-29 09:00:19 +0100
commit     4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 (patch)
tree       967e2c1ee6e2c29b9b6ffaff3e8978f4a43a529d /training
parent     Updated 1-cycle scheduler (diff)
Training improvements
Diffstat (limited to 'training')
-rw-r--r--  training/lr.py            |  7
-rw-r--r--  training/optimization.py  | 43
2 files changed, 27 insertions, 23 deletions
diff --git a/training/lr.py b/training/lr.py
index c0e9b3f..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -90,6 +90,7 @@ class LRFinder():
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
+                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
                 if loss < best_loss:
                     best_loss = loss
                 if acc > best_acc:
@@ -132,9 +133,9 @@ class LRFinder():
         ax_loss.set_xlabel("Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        # ax_acc = ax_loss.twinx()
-        # ax_acc.plot(lrs, accs, color='blue')
-        # ax_acc.set_ylabel("Accuracy")
+        ax_acc = ax_loss.twinx()
+        ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
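For context, this change applies the same exponential smoothing to the accuracy curve that was already applied to the loss curve, and re-enables the twin accuracy axis in the plot. A minimal standalone sketch of that smoothing rule follows; only `smooth_f`, `losses`, and `accs` come from the diff, while the helper function and sample values are illustrative:

```python
# Illustrative sketch of the smoothing rule from LRFinder, not the full class.
smooth_f = 0.05          # blend factor; the repo's actual default is not shown here
losses: list[float] = []
accs: list[float] = []

def record(loss: float, acc: float) -> None:
    """Append a loss/accuracy pair, smoothing against the previous entry as in the diff."""
    if losses and smooth_f > 0:
        loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
        acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
    losses.append(loss)
    accs.append(acc)

record(1.00, 0.10)
record(0.80, 0.30)       # pulled heavily toward the previous point: loss ~0.99, acc ~0.11
```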
diff --git a/training/optimization.py b/training/optimization.py
index a0c8673..dfee2b5 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,4 +1,7 @@
 import math
+from typing import Literal
+
+import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.utils import logging
@@ -6,41 +9,41 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1):
-    """
-    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
-    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
-    Args:
-        optimizer ([`~torch.optim.Optimizer`]):
-            The optimizer for which to schedule the learning rate.
-        num_training_steps (`int`):
-            The total number of training steps.
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    Return:
-        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-
+def get_one_cycle_schedule(
+    optimizer: torch.optim.Optimizer,
+    num_training_steps: int,
+    warmup: Literal["cos", "linear"] = "cos",
+    annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    min_lr: int = 0.04,
+    mid_point: int = 0.3,
+    last_epoch: int = -1
+):
     def lr_lambda(current_step: int):
         thresh_up = int(num_training_steps * min(mid_point, 0.5))
 
         if current_step < thresh_up:
-            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
+            progress = float(current_step) / float(max(1, thresh_up))
+
+            if warmup == "linear":
+                return min_lr + progress * (1 - min_lr)
+
+            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
 
         if annealing == "linear":
             thresh_down = thresh_up * 2
 
             if current_step < thresh_down:
-                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
+                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
+                return min_lr + progress * (1 - min_lr)
 
             progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
-            return max(0.0, progress) * min_lr
+            return progress * min_lr
 
         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
 
         if annealing == "half_cos":
-            return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))
+            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
 
-        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+        return 0.5 * (1.0 + math.cos(math.pi * progress))
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
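As a quick orientation for the new signature, a minimal usage sketch is below: the schedule warms up over the first `mid_point` fraction of steps (cosine by default, or linear via the new `warmup` argument) and then anneals back toward zero according to `annealing`. The model, optimizer, and step count here are illustrative placeholders, not values from this repository.

```python
# Minimal sketch of wiring the updated scheduler into a training loop.
# The model, optimizer, and step count are placeholders for illustration.
import torch

from training.optimization import get_one_cycle_schedule

model = torch.nn.Linear(16, 16)                      # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_training_steps = 1000

lr_scheduler = get_one_cycle_schedule(
    optimizer,
    num_training_steps,
    warmup="cos",        # new in this commit: cosine or linear warmup
    annealing="cos",     # "cos", "half_cos", or "linear" annealing
    mid_point=0.3,       # warmup ends after 30% of the steps
)

for step in range(num_training_steps):
    optimizer.step()     # loss computation and backward pass omitted
    lr_scheduler.step()  # LambdaLR multiplies the base LR by lr_lambda(step)
```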