Diffstat (limited to 'training')
-rw-r--r--  training/common.py  55
1 file changed, 55 insertions, 0 deletions
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
 
 def generate_class_images(
     accelerator,
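
For context, a minimal sketch of how the new get_scheduler helper might be called from a training script. The optimizer, step counts, and the "cos" warmup/annealing names are illustrative assumptions, not values taken from this commit:

    import torch

    from training.common import get_scheduler

    # Toy model/optimizer standing in for the real training setup (assumption).
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    lr_scheduler = get_scheduler(
        "one_cycle",                    # or "cosine_with_restarts", or any id
                                        # accepted by diffusers' get_scheduler
        min_lr=1e-6,                    # rescaled to min_lr / lr for one_cycle
        lr=1e-4,
        warmup_func="cos",              # assumed values; passed through to
        annealing_func="cos",           # training.optimization.get_one_cycle_schedule
        warmup_exp=1,
        annealing_exp=1,
        cycles=None,                    # cosine_with_restarts derives a default
        warmup_epochs=10,
        optimizer=optimizer,
        max_train_steps=1000,
        num_update_steps_per_epoch=100,
        gradient_accumulation_steps=1,
    )

    # One scheduler step per optimizer update, as in a typical training loop.
    optimizer.step()
    lr_scheduler.step()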