diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-02 11:27:38 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-02 11:27:38 +0100 | 
| commit | c24a24c77a825e71de2e67c1515c84d2b77701fa (patch) | |
| tree | 2ed4616191c8f61c3f502375bd6698e0d29d041e | |
| parent | Update (diff) | |
| download | textual-inversion-diff-c24a24c77a825e71de2e67c1515c84d2b77701fa.tar.gz textual-inversion-diff-c24a24c77a825e71de2e67c1515c84d2b77701fa.tar.bz2 textual-inversion-diff-c24a24c77a825e71de2e67c1515c84d2b77701fa.zip | |
Improved one cycle scheduler
| -rw-r--r-- | train_ti.py | 2 | ||||
| -rw-r--r-- | training/optimization.py | 92 | 
2 files changed, 63 insertions, 31 deletions
| diff --git a/train_ti.py b/train_ti.py index 870bd40..775b918 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -250,7 +250,7 @@ def parse_args(): | |||
| 250 | parser.add_argument( | 250 | parser.add_argument( | 
| 251 | "--lr_annealing_exp", | 251 | "--lr_annealing_exp", | 
| 252 | type=int, | 252 | type=int, | 
| 253 | default=2, | 253 | default=1, | 
| 254 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 254 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 
| 255 | ) | 255 | ) | 
| 256 | parser.add_argument( | 256 | parser.add_argument( | 
| diff --git a/training/optimization.py b/training/optimization.py index 725599b..cb7f088 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -1,5 +1,6 @@ | |||
| 1 | import math | 1 | import math | 
| 2 | from typing import Literal | 2 | from typing import NamedTuple, Literal, Callable | 
| 3 | from functools import partial | ||
| 3 | 4 | ||
| 4 | import torch | 5 | import torch | 
| 5 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR | 
| @@ -9,6 +10,40 @@ from diffusers.utils import logging | |||
| 9 | logger = logging.get_logger(__name__) | 10 | logger = logging.get_logger(__name__) | 
| 10 | 11 | ||
| 11 | 12 | ||
| 13 | class OneCyclePhase(NamedTuple): | ||
| 14 | step_min: int | ||
| 15 | step_max: int | ||
| 16 | min: float | ||
| 17 | max: float | ||
| 18 | func: Callable[[float], float] | ||
| 19 | |||
| 20 | |||
| 21 | def warmup_linear(progress: float): | ||
| 22 | return progress | ||
| 23 | |||
| 24 | |||
| 25 | def warmup_cos(exp: int, progress: float): | ||
| 26 | lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) | ||
| 27 | lr = lr ** (exp - (exp - 1) * progress) | ||
| 28 | return lr | ||
| 29 | |||
| 30 | |||
| 31 | def anneal_linear(progress: float): | ||
| 32 | return 1 - progress | ||
| 33 | |||
| 34 | |||
| 35 | def anneal_half_cos(exp: int, progress: float): | ||
| 36 | lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) | ||
| 37 | lr = lr ** (exp - (exp - 1) * progress) | ||
| 38 | return lr | ||
| 39 | |||
| 40 | |||
| 41 | def anneal_cos(exp: int, progress: float): | ||
| 42 | lr = 0.5 * (1.0 + math.cos(math.pi * progress)) | ||
| 43 | lr = lr ** (exp - (exp - 1) * progress) | ||
| 44 | return lr | ||
| 45 | |||
| 46 | |||
| 12 | def get_one_cycle_schedule( | 47 | def get_one_cycle_schedule( | 
| 13 | optimizer: torch.optim.Optimizer, | 48 | optimizer: torch.optim.Optimizer, | 
| 14 | num_training_steps: int, | 49 | num_training_steps: int, | 
| @@ -20,38 +55,35 @@ def get_one_cycle_schedule( | |||
| 20 | mid_point: int = 0.3, | 55 | mid_point: int = 0.3, | 
| 21 | last_epoch: int = -1 | 56 | last_epoch: int = -1 | 
| 22 | ): | 57 | ): | 
| 23 | def lr_lambda(current_step: int): | 58 | if warmup == "linear": | 
| 24 | thresh_up = int(num_training_steps * min(mid_point, 0.5)) | 59 | warmup_func = warmup_linear | 
| 25 | 60 | else: | |
| 26 | if current_step < thresh_up: | 61 | warmup_func = partial(warmup_cos, warmup_exp) | 
| 27 | progress = float(current_step) / float(max(1, thresh_up)) | ||
| 28 | |||
| 29 | if warmup == "linear": | ||
| 30 | return min_lr + progress * (1 - min_lr) | ||
| 31 | 62 | ||
| 32 | lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) | 63 | if annealing == "linear": | 
| 33 | lr = lr ** (warmup_exp - (warmup_exp - 1) * progress) | 64 | anneal_func = anneal_linear | 
| 34 | return min_lr + lr * (1 - min_lr) | 65 | elif annealing == "half_cos": | 
| 66 | anneal_func = partial(anneal_half_cos, annealing_exp) | ||
| 67 | else: | ||
| 68 | anneal_func = partial(anneal_cos, annealing_exp) | ||
| 35 | 69 | ||
| 36 | if annealing == "linear": | 70 | thresh_up = int(num_training_steps * min(mid_point, 0.5)) | 
| 37 | thresh_down = thresh_up * 2 | ||
| 38 | 71 | ||
| 39 | if current_step < thresh_down: | 72 | if annealing == "linear": | 
| 40 | progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) | 73 | thresh_down = thresh_up * 2 | 
| 41 | return min_lr + progress * (1 - min_lr) | 74 | phases = [ | 
| 75 | OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func), | ||
| 76 | OneCyclePhase(thresh_up, thresh_down, min_lr, 1, anneal_func), | ||
| 77 | OneCyclePhase(thresh_down, num_training_steps, 0, min_lr, anneal_func), | ||
| 78 | ] | ||
| 79 | else: | ||
| 80 | phases = [ | ||
| 81 | OneCyclePhase(0, thresh_up, 0, 1, warmup_func), | ||
| 82 | OneCyclePhase(thresh_up, num_training_steps, 0, 1, anneal_func), | ||
| 83 | ] | ||
| 42 | 84 | ||
| 43 | progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) | 85 | def lr_lambda(current_step: int): | 
| 44 | return progress * min_lr | 86 | phase = [p for p in phases if current_step >= p.step_min][-1] | 
| 45 | 87 | return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) | |
| 46 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) | ||
| 47 | |||
| 48 | if annealing == "half_cos": | ||
| 49 | lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) | ||
| 50 | lr = lr ** (annealing_exp - (annealing_exp - 1) * progress) | ||
| 51 | return lr | ||
| 52 | |||
| 53 | lr = 0.5 * (1.0 + math.cos(math.pi * progress)) | ||
| 54 | lr = lr ** (annealing_exp - (annealing_exp - 1) * progress) | ||
| 55 | return lr | ||
| 56 | 88 | ||
| 57 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 89 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 
