Diffstat (limited to 'training')
-rw-r--r-- training/common.py | 42
1 file changed, 22 insertions(+), 20 deletions(-)
diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
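For context on the new step accounting in get_scheduler: the incoming num_training_steps_per_epoch is rounded up to a whole number of optimizer updates and then scaled back by gradient_accumulation_steps, so the warmup and annealing lengths stay aligned with how often the scheduler is stepped. A minimal sketch with made-up numbers (none of the values below come from the repository):

    import math

    batches_per_epoch = 103            # illustrative stand-in for num_training_steps_per_epoch
    gradient_accumulation_steps = 4
    train_epochs = 10
    warmup_epochs = 2

    # Round up to a whole number of optimizer updates, then scale back to
    # per-batch scheduler steps.
    num_training_steps_per_epoch = math.ceil(
        batches_per_epoch / gradient_accumulation_steps
    ) * gradient_accumulation_steps    # ceil(103 / 4) * 4 = 104

    num_training_steps = train_epochs * num_training_steps_per_epoch   # 1040
    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch    # 208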
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
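The num_class_images != 0 check becomes an explicit with_prior flag; the branch itself still splits the concatenated batch into an instance half and a class (prior) half. A hedged sketch of how the chunked losses are typically combined with prior_loss_weight (the helper name and the mean reduction are illustrative, not lifted from loss_step):

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(model_pred, target, prior_loss_weight):
        # Instance samples sit in the first half of the batch, class (prior)
        # samples in the second half, so split both tensors down the middle.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Weight the prior term so the class images regularise without dominating.
        return instance_loss + prior_loss_weight * prior_loss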
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
@@ -309,7 +311,7 @@ def train_loop(
 
                     local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= num_train_steps:
+                    if global_step >= num_training_steps:
                         break
 
         accelerator.wait_for_everyone()
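Taken together, the train_loop renames keep the two progress bars and the early-exit check driven by the same counters. A rough sketch, assuming train_dataloader, val_dataloader, num_epochs, and gradient_accumulation_steps are in scope as they are in train_loop:

    import math
    from tqdm.auto import tqdm

    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_val_steps_per_epoch = len(val_dataloader)

    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    # The local bar covers one epoch (train + validation); the global bar covers the whole run,
    # and training stops once global_step reaches num_training_steps.
    local_progress_bar = tqdm(range(num_training_steps_per_epoch + num_val_steps_per_epoch))
    global_progress_bar = tqdm(range(num_training_steps + num_val_steps))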