diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/optimization.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py new file mode 100644 index 0000000..012beed --- /dev/null +++ b/training/optimization.py | |||
| @@ -0,0 +1,42 @@ | |||
| 1 | import math | ||
| 2 | from torch.optim.lr_scheduler import LambdaLR | ||
| 3 | |||
| 4 | from diffusers.utils import logging | ||
| 5 | |||
| 6 | logger = logging.get_logger(__name__) | ||
| 7 | |||
| 8 | |||
def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.42, last_epoch=-1):
    """
    Create a one-cycle learning-rate schedule.

    The multiplicative LR factor ramps up linearly from `min_lr` to 1.0 over the
    first `mid_point` fraction of training (capped at 50%), then anneals back down.
    With `annealing="cos"` the factor follows a half-cosine from 1.0 to 0.0 over
    the remaining steps. With `annealing="linear"` it first decays linearly from
    1.0 back to `min_lr` over a second window equal in length to the warmup, then
    decays linearly from `min_lr` to 0.0 by `num_training_steps`.

    NOTE(review): the cosine branch anneals all the way to 0 and does not floor at
    `min_lr`, unlike the linear branch's plateau — presumably intentional; confirm.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_training_steps (`int`):
            The total number of training steps.
        annealing (`str`, *optional*, defaults to `"cos"`):
            Annealing strategy after warmup: `"linear"` for the two-phase linear
            decay described above; any other value selects cosine annealing.
        min_lr (`float`, *optional*, defaults to 0.05):
            LR factor at step 0 and at the end of the linear annealing window,
            expressed as a fraction of the optimizer's initial LR.
        mid_point (`float`, *optional*, defaults to 0.42):
            Fraction of `num_training_steps` used for warmup (values above 0.5
            are clamped to 0.5).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # End of the warmup ramp; warmup never exceeds half of training.
        thresh_up = int(num_training_steps * min(mid_point, 0.5))

        if current_step < thresh_up:
            # Linear ramp from min_lr to 1.0. max(1, ...) guards division by zero.
            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)

        if annealing == "linear":
            # Mirror of the warmup window: linear decay back down to min_lr.
            thresh_down = thresh_up * 2

            if current_step < thresh_down:
                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)

            # Final phase: linear decay from min_lr to 0 by the last step.
            return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))) * min_lr
        else:
            # Cosine annealing from 1.0 at thresh_up down to 0.0 at the last step.
            progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))

            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
