Diffstat (limited to 'training')
 training/common.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+), 0 deletions(-)
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
 
 def generate_class_images(
     accelerator,
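Not part of the commit above: a minimal usage sketch of how the new get_scheduler helper might be wired into a training script. Every hyperparameter value below is an illustrative assumption, not taken from this repository; the "cosine" id exercises the fallback branch that delegates to diffusers' built-in get_scheduler.

# Illustrative sketch only (not from this commit); all values are assumptions.
import torch

from training.common import get_scheduler

model = torch.nn.Linear(8, 8)  # stand-in for the real UNet / text encoder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine",                        # any id other than "one_cycle" or
                                     # "cosine_with_restarts" falls through
                                     # to diffusers' get_scheduler
    min_lr=1e-6,                     # only read by the "one_cycle" branch
    lr=1e-4,
    warmup_func="cos",               # only read by the "one_cycle" branch
    annealing_func="cos",
    warmup_exp=1,
    annealing_exp=1,
    cycles=None,                     # "cosine_with_restarts" derives a default
    warmup_epochs=2,
    optimizer=optimizer,
    max_train_steps=1000,
    num_update_steps_per_epoch=100,
    gradient_accumulation_steps=1,
)

# warmup_steps resolves to 2 * 100 * 1 = 200 of the 1000 scheduler steps.
for _ in range(1000):
    optimizer.step()
    lr_scheduler.step()

Note the default in the "cosine_with_restarts" branch: when cycles is None, the helper uses roughly ceil(sqrt((max_train_steps - warmup_steps) / num_update_steps_per_epoch)) restarts, so longer runs get more restarts but at a sub-linear rate.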
