 train_ti.py              | 21
 training/lr.py           |  7
 training/optimization.py | 43
 3 files changed, 42 insertions(+), 29 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index d7696e5..b1f6a49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -903,12 +903,21 @@ def main():

         text_encoder.eval()

+        cur_loss_val = AverageMeter()
+        cur_acc_val = AverageMeter()
+
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
                 loss, acc, bsz = loop(batch)

-                avg_loss_val.update(loss.detach_(), bsz)
-                avg_acc_val.update(acc.detach_(), bsz)
+                loss = loss.detach_()
+                acc = acc.detach_()
+
+                cur_loss_val.update(loss, bsz)
+                cur_acc_val.update(acc, bsz)
+
+                avg_loss_val.update(loss, bsz)
+                avg_acc_val.update(acc, bsz)

                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
@@ -921,10 +930,10 @@ def main():
                 }
                 local_progress_bar.set_postfix(**logs)

-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)

             local_progress_bar.clear()
             global_progress_bar.clear()
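
For context: the new cur_loss_val / cur_acc_val trackers use the same AverageMeter interface (update(value, n) and .avg) that avg_loss_val / avg_acc_val already rely on. A minimal sketch of such a meter, stated as an assumption about that interface rather than the repository's actual implementation:

class AverageMeter:
    """Running weighted average of a per-batch metric (hypothetical minimal version)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # Weight each per-batch mean by the batch size n; works for floats and for
        # detached scalar tensors, so .avg.item() is valid afterwards.
        self.sum = self.sum + value * n
        self.count += n
        self.avg = self.sum / self.count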
diff --git a/training/lr.py b/training/lr.py
index c0e9b3f..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -90,6 +90,7 @@ class LRFinder():
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
+                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
                 if loss < best_loss:
                     best_loss = loss
                 if acc > best_acc:
@@ -132,9 +133,9 @@ class LRFinder():
         ax_loss.set_xlabel("Learning rate")
         ax_loss.set_ylabel("Loss")

-        # ax_acc = ax_loss.twinx()
-        # ax_acc.plot(lrs, accs, color='blue')
-        # ax_acc.set_ylabel("Accuracy")
+        ax_acc = ax_loss.twinx()
+        ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_ylabel("Accuracy")

         print("LR suggestion: steepest gradient")
         min_grad_idx = None
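
The added line applies to the accuracy curve the same exponential smoothing the finder already applies to the loss, and the uncommented block plots that curve on a second y-axis. The smoothing recurrence, written out as a standalone sketch (the function name and the 0.05 default are illustrative, not part of the diff):

def smooth(values, smooth_f=0.05):
    # Exponential moving average: blend each raw value with the previous
    # smoothed value, mirroring the per-step update LRFinder applies above.
    out = [values[0]]
    for v in values[1:]:
        out.append(smooth_f * v + (1 - smooth_f) * out[-1])
    return out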
diff --git a/training/optimization.py b/training/optimization.py
index a0c8673..dfee2b5 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,4 +1,7 @@
 import math
+from typing import Literal
+
+import torch
 from torch.optim.lr_scheduler import LambdaLR

 from diffusers.utils import logging
@@ -6,41 +9,41 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)


-def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1):
-    """
-    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
-    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
-    Args:
-        optimizer ([`~torch.optim.Optimizer`]):
-            The optimizer for which to schedule the learning rate.
-        num_training_steps (`int`):
-            The total number of training steps.
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    Return:
-        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-
+def get_one_cycle_schedule(
+    optimizer: torch.optim.Optimizer,
+    num_training_steps: int,
+    warmup: Literal["cos", "linear"] = "cos",
+    annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    min_lr: int = 0.04,
+    mid_point: int = 0.3,
+    last_epoch: int = -1
+):
     def lr_lambda(current_step: int):
         thresh_up = int(num_training_steps * min(mid_point, 0.5))

         if current_step < thresh_up:
-            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
+            progress = float(current_step) / float(max(1, thresh_up))
+
+            if warmup == "linear":
+                return min_lr + progress * (1 - min_lr)
+
+            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))

         if annealing == "linear":
             thresh_down = thresh_up * 2

             if current_step < thresh_down:
-                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
+                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
+                return min_lr + progress * (1 - min_lr)

             progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
-            return max(0.0, progress) * min_lr
+            return progress * min_lr

         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))

         if annealing == "half_cos":
-            return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))
+            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))

-        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+        return 0.5 * (1.0 + math.cos(math.pi * progress))

     return LambdaLR(optimizer, lr_lambda, last_epoch)
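
For reference, the reworked scheduler can be exercised on its own to inspect the learning-rate curve it produces. A minimal sketch, assuming it is run from the repository root so that training.optimization is importable; the dummy parameter, base LR, and step count are illustrative:

import torch

from training.optimization import get_one_cycle_schedule

# Dummy parameter and optimizer, only needed to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)

num_steps = 100
scheduler = get_one_cycle_schedule(
    optimizer,
    num_training_steps=num_steps,
    warmup="cos",       # cosine ramp from min_lr over the first mid_point fraction of steps
    annealing="cos",    # cosine decay towards 0 over the remaining steps
    min_lr=0.04,        # a multiplier on the optimizer's base LR, not an absolute value
    mid_point=0.3,
)

for step in range(num_steps):
    optimizer.step()
    scheduler.step()
    if step % 10 == 0:
        print(step, scheduler.get_last_lr()[0])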