diff options
| -rw-r--r-- | train_ti.py | 21 | ||||
| -rw-r--r-- | training/lr.py | 7 | ||||
| -rw-r--r-- | training/optimization.py | 43 | 
3 files changed, 42 insertions, 29 deletions
| diff --git a/train_ti.py b/train_ti.py index d7696e5..b1f6a49 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -903,12 +903,21 @@ def main(): | |||
| 903 | 903 | ||
| 904 | text_encoder.eval() | 904 | text_encoder.eval() | 
| 905 | 905 | ||
| 906 | cur_loss_val = AverageMeter() | ||
| 907 | cur_acc_val = AverageMeter() | ||
| 908 | |||
| 906 | with torch.inference_mode(): | 909 | with torch.inference_mode(): | 
| 907 | for step, batch in enumerate(val_dataloader): | 910 | for step, batch in enumerate(val_dataloader): | 
| 908 | loss, acc, bsz = loop(batch) | 911 | loss, acc, bsz = loop(batch) | 
| 909 | 912 | ||
| 910 | avg_loss_val.update(loss.detach_(), bsz) | 913 | loss = loss.detach_() | 
| 911 | avg_acc_val.update(acc.detach_(), bsz) | 914 | acc = acc.detach_() | 
| 915 | |||
| 916 | cur_loss_val.update(loss, bsz) | ||
| 917 | cur_acc_val.update(acc, bsz) | ||
| 918 | |||
| 919 | avg_loss_val.update(loss, bsz) | ||
| 920 | avg_acc_val.update(acc, bsz) | ||
| 912 | 921 | ||
| 913 | local_progress_bar.update(1) | 922 | local_progress_bar.update(1) | 
| 914 | global_progress_bar.update(1) | 923 | global_progress_bar.update(1) | 
| @@ -921,10 +930,10 @@ def main(): | |||
| 921 | } | 930 | } | 
| 922 | local_progress_bar.set_postfix(**logs) | 931 | local_progress_bar.set_postfix(**logs) | 
| 923 | 932 | ||
| 924 | accelerator.log({ | 933 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 
| 925 | "val/loss": avg_loss_val.avg.item(), | 934 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 
| 926 | "val/acc": avg_acc_val.avg.item(), | 935 | |
| 927 | }, step=global_step) | 936 | accelerator.log(logs, step=global_step) | 
| 928 | 937 | ||
| 929 | local_progress_bar.clear() | 938 | local_progress_bar.clear() | 
| 930 | global_progress_bar.clear() | 939 | global_progress_bar.clear() | 
| diff --git a/training/lr.py b/training/lr.py index c0e9b3f..0c5ce9e 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -90,6 +90,7 @@ class LRFinder(): | |||
| 90 | else: | 90 | else: | 
| 91 | if smooth_f > 0: | 91 | if smooth_f > 0: | 
| 92 | loss = smooth_f * loss + (1 - smooth_f) * losses[-1] | 92 | loss = smooth_f * loss + (1 - smooth_f) * losses[-1] | 
| 93 | acc = smooth_f * acc + (1 - smooth_f) * accs[-1] | ||
| 93 | if loss < best_loss: | 94 | if loss < best_loss: | 
| 94 | best_loss = loss | 95 | best_loss = loss | 
| 95 | if acc > best_acc: | 96 | if acc > best_acc: | 
| @@ -132,9 +133,9 @@ class LRFinder(): | |||
| 132 | ax_loss.set_xlabel("Learning rate") | 133 | ax_loss.set_xlabel("Learning rate") | 
| 133 | ax_loss.set_ylabel("Loss") | 134 | ax_loss.set_ylabel("Loss") | 
| 134 | 135 | ||
| 135 | # ax_acc = ax_loss.twinx() | 136 | ax_acc = ax_loss.twinx() | 
| 136 | # ax_acc.plot(lrs, accs, color='blue') | 137 | ax_acc.plot(lrs, accs, color='blue') | 
| 137 | # ax_acc.set_ylabel("Accuracy") | 138 | ax_acc.set_ylabel("Accuracy") | 
| 138 | 139 | ||
| 139 | print("LR suggestion: steepest gradient") | 140 | print("LR suggestion: steepest gradient") | 
| 140 | min_grad_idx = None | 141 | min_grad_idx = None | 
| diff --git a/training/optimization.py b/training/optimization.py index a0c8673..dfee2b5 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -1,4 +1,7 @@ | |||
| 1 | import math | 1 | import math | 
| 2 | from typing import Literal | ||
| 3 | |||
| 4 | import torch | ||
| 2 | from torch.optim.lr_scheduler import LambdaLR | 5 | from torch.optim.lr_scheduler import LambdaLR | 
| 3 | 6 | ||
| 4 | from diffusers.utils import logging | 7 | from diffusers.utils import logging | 
| @@ -6,41 +9,41 @@ from diffusers.utils import logging | |||
| 6 | logger = logging.get_logger(__name__) | 9 | logger = logging.get_logger(__name__) | 
| 7 | 10 | ||
| 8 | 11 | ||
| 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1): | 12 | def get_one_cycle_schedule( | 
| 10 | """ | 13 | optimizer: torch.optim.Optimizer, | 
| 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after | 14 | num_training_steps: int, | 
| 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. | 15 | warmup: Literal["cos", "linear"] = "cos", | 
| 13 | Args: | 16 | annealing: Literal["cos", "half_cos", "linear"] = "cos", | 
| 14 | optimizer ([`~torch.optim.Optimizer`]): | 17 | min_lr: int = 0.04, | 
| 15 | The optimizer for which to schedule the learning rate. | 18 | mid_point: int = 0.3, | 
| 16 | num_training_steps (`int`): | 19 | last_epoch: int = -1 | 
| 17 | The total number of training steps. | 20 | ): | 
| 18 | last_epoch (`int`, *optional*, defaults to -1): | ||
| 19 | The index of the last epoch when resuming training. | ||
| 20 | Return: | ||
| 21 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. | ||
| 22 | """ | ||
| 23 | |||
| 24 | def lr_lambda(current_step: int): | 21 | def lr_lambda(current_step: int): | 
| 25 | thresh_up = int(num_training_steps * min(mid_point, 0.5)) | 22 | thresh_up = int(num_training_steps * min(mid_point, 0.5)) | 
| 26 | 23 | ||
| 27 | if current_step < thresh_up: | 24 | if current_step < thresh_up: | 
| 28 | return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr) | 25 | progress = float(current_step) / float(max(1, thresh_up)) | 
| 26 | |||
| 27 | if warmup == "linear": | ||
| 28 | return min_lr + progress * (1 - min_lr) | ||
| 29 | |||
| 30 | return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) | ||
| 29 | 31 | ||
| 30 | if annealing == "linear": | 32 | if annealing == "linear": | 
| 31 | thresh_down = thresh_up * 2 | 33 | thresh_down = thresh_up * 2 | 
| 32 | 34 | ||
| 33 | if current_step < thresh_down: | 35 | if current_step < thresh_down: | 
| 34 | return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr) | 36 | progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) | 
| 37 | return min_lr + progress * (1 - min_lr) | ||
| 35 | 38 | ||
| 36 | progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) | 39 | progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) | 
| 37 | return max(0.0, progress) * min_lr | 40 | return progress * min_lr | 
| 38 | 41 | ||
| 39 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) | 42 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) | 
| 40 | 43 | ||
| 41 | if annealing == "half_cos": | 44 | if annealing == "half_cos": | 
| 42 | return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))) | 45 | return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) | 
| 43 | 46 | ||
| 44 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) | 47 | return 0.5 * (1.0 + math.cos(math.pi * progress)) | 
| 45 | 48 | ||
| 46 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 49 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 
