Diffstat (limited to 'training/common.py')
-rw-r--r--  training/common.py  55
1 files changed, 55 insertions, 0 deletions
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
+
 def generate_class_images(
     accelerator,
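
For reference, a minimal usage sketch of the new get_scheduler helper added by this patch. The argument values below are illustrative assumptions, not taken from the repository's training scripts, and the warmup/annealing names are only assumed to be accepted by get_one_cycle_schedule.

    # Hypothetical usage sketch; values are assumptions, not from the repo's scripts.
    import torch

    from training.common import get_scheduler

    params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model parameters
    optimizer = torch.optim.AdamW(params, lr=1e-4)

    lr_scheduler = get_scheduler(
        "one_cycle",                  # or "cosine_with_restarts", or any id accepted by diffusers' get_scheduler
        min_lr=1e-6,                  # for "one_cycle" this is rescaled to min_lr / lr inside the helper
        lr=1e-4,
        warmup_func="linear",         # assumed value, passed through to get_one_cycle_schedule
        annealing_func="cos",         # assumed value
        warmup_exp=1,
        annealing_exp=1,
        cycles=None,                  # None lets the helper derive a cycle count for "cosine_with_restarts"
        warmup_epochs=10,
        optimizer=optimizer,
        max_train_steps=3000,
        num_update_steps_per_epoch=100,
        gradient_accumulation_steps=1,
    )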