diff options
| -rw-r--r-- | training/optimization.py | 8 | 
1 files changed, 4 insertions, 4 deletions
| diff --git a/training/optimization.py b/training/optimization.py index 012beed..0fd7ec8 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -6,7 +6,7 @@ from diffusers.utils import logging | |||
| 6 | logger = logging.get_logger(__name__) | 6 | logger = logging.get_logger(__name__) | 
| 7 | 7 | ||
| 8 | 8 | ||
| 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.42, last_epoch=-1): | 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.43, last_epoch=-1): | 
| 10 | """ | 10 | """ | 
| 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after | 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after | 
| 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. | 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. | 
| @@ -33,10 +33,10 @@ def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_l | |||
| 33 | if current_step < thresh_down: | 33 | if current_step < thresh_down: | 
| 34 | return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr) | 34 | return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr) | 
| 35 | 35 | ||
| 36 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))) * min_lr | 36 | progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) | 
| 37 | return max(0.0, progress) * min_lr | ||
| 37 | else: | 38 | else: | 
| 38 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) | 39 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) | 
| 39 | 40 | return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))) | |
| 40 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) | ||
| 41 | 41 | ||
| 42 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 42 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 
