 train_ti.py              | 21 +++++++++++++++------
 training/lr.py           |  7 ++++---
 training/optimization.py | 43 +++++++++++++++++++++++--------------------
 3 files changed, 42 insertions(+), 29 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index d7696e5..b1f6a49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -903,12 +903,21 @@ def main():
 
         text_encoder.eval()
 
+        cur_loss_val = AverageMeter()
+        cur_acc_val = AverageMeter()
+
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
                 loss, acc, bsz = loop(batch)
 
-                avg_loss_val.update(loss.detach_(), bsz)
-                avg_acc_val.update(acc.detach_(), bsz)
+                loss = loss.detach_()
+                acc = acc.detach_()
+
+                cur_loss_val.update(loss, bsz)
+                cur_acc_val.update(acc, bsz)
+
+                avg_loss_val.update(loss, bsz)
+                avg_acc_val.update(acc, bsz)
 
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
@@ -921,10 +930,10 @@ def main():
                 }
                 local_progress_bar.set_postfix(**logs)
 
-        accelerator.log({
-            "val/loss": avg_loss_val.avg.item(),
-            "val/acc": avg_acc_val.avg.item(),
-        }, step=global_step)
+        logs["val/cur_loss"] = cur_loss_val.avg.item()
+        logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+        accelerator.log(logs, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
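The validation loop now keeps two sets of meters: cur_loss_val / cur_acc_val are created fresh before each validation pass and logged as val/cur_loss and val/cur_acc, while avg_loss_val / avg_acc_val (created outside this hunk) keep their longer-running averages. The AverageMeter class itself is not part of this diff; a minimal sketch, assuming a batch-size-weighted running mean behind the .update(value, n) / .avg interface used above:

# Hypothetical AverageMeter sketch -- the repo's own class may differ.
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0  # stays a tensor if tensors are passed in, so .avg.item() works

    def update(self, value, n=1):
        # Weight each update by the batch size n so avg is a per-sample mean.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count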
diff --git a/training/lr.py b/training/lr.py
index c0e9b3f..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -90,6 +90,7 @@ class LRFinder():
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
+                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
                 if loss < best_loss:
                     best_loss = loss
                 if acc > best_acc:
@@ -132,9 +133,9 @@ class LRFinder():
         ax_loss.set_xlabel("Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        # ax_acc = ax_loss.twinx()
-        # ax_acc.plot(lrs, accs, color='blue')
-        # ax_acc.set_ylabel("Accuracy")
+        ax_acc = ax_loss.twinx()
+        ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
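The LRFinder change applies the same exponential smoothing to accuracy that was already applied to loss, so the re-enabled twin accuracy axis plots a curve as smooth as the loss curve. A standalone sketch of that smoothing rule (the helper name, input values, and smooth_f value are illustrative, not taken from this repo):

# Each new reading is blended with the previous smoothed value.
def smooth(history, new_value, smooth_f=0.05):
    if not history:  # first reading is kept as-is
        return new_value
    return smooth_f * new_value + (1 - smooth_f) * history[-1]

losses, accs = [], []
for raw_loss, raw_acc in [(2.30, 0.10), (1.90, 0.22), (2.50, 0.18)]:
    losses.append(smooth(losses, raw_loss))
    accs.append(smooth(accs, raw_acc))

print(losses)  # approximately [2.30, 2.28, 2.29]: outliers barely move the curve
print(accs)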
diff --git a/training/optimization.py b/training/optimization.py
index a0c8673..dfee2b5 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,4 +1,7 @@
 import math
+from typing import Literal
+
+import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.utils import logging
@@ -6,41 +9,41 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1):
-    """
-    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
-    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
-    Args:
-        optimizer ([`~torch.optim.Optimizer`]):
-            The optimizer for which to schedule the learning rate.
-        num_training_steps (`int`):
-            The total number of training steps.
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    Return:
-        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-
+def get_one_cycle_schedule(
+    optimizer: torch.optim.Optimizer,
+    num_training_steps: int,
+    warmup: Literal["cos", "linear"] = "cos",
+    annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    min_lr: float = 0.04,
+    mid_point: float = 0.3,
+    last_epoch: int = -1
+):
     def lr_lambda(current_step: int):
         thresh_up = int(num_training_steps * min(mid_point, 0.5))
 
         if current_step < thresh_up:
-            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
+            progress = float(current_step) / float(max(1, thresh_up))
+
+            if warmup == "linear":
+                return min_lr + progress * (1 - min_lr)
+
+            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
 
         if annealing == "linear":
             thresh_down = thresh_up * 2
 
             if current_step < thresh_down:
-                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
+                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
+                return min_lr + progress * (1 - min_lr)
 
             progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
-            return max(0.0, progress) * min_lr
+            return progress * min_lr
 
         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
 
         if annealing == "half_cos":
-            return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))
+            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
 
-        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+        return 0.5 * (1.0 + math.cos(math.pi * progress))
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
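A quick way to sanity-check the new warmup and annealing options is to drive the scheduler with a throwaway optimizer and print the learning rates it produces; the parameter, optimizer choice, and step count below are placeholders, only get_one_cycle_schedule and its arguments come from this commit:

import torch

from training.optimization import get_one_cycle_schedule

# Dummy parameter/optimizer just to exercise the schedule.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

scheduler = get_one_cycle_schedule(
    optimizer,
    num_training_steps=100,
    warmup="cos",       # new option in this commit: "cos" (default) or "linear"
    annealing="cos",    # "cos", "half_cos" or "linear"
    mid_point=0.3,      # warmup ends at 30% of the training steps
)

for step in range(100):
    optimizer.step()    # step the optimizer before the scheduler
    scheduler.step()
    if step % 25 == 0:
        print(step, scheduler.get_last_lr()[0])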