Diffstat (limited to 'training')
 -rw-r--r--  training/common.py  42
 1 file changed, 22 insertions(+), 20 deletions(-)
diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
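Note on the reworked step accounting above, as a sketch assuming the caller passes the raw number of batches per epoch as num_training_steps_per_epoch (the concrete values below are made up for illustration): the count is first rounded up to a whole number of accumulation cycles, and the total and warmup step counts are then derived from that rounded figure instead of multiplying by gradient_accumulation_steps at every call site.

    import math

    # Hypothetical inputs, for illustration only.
    batches_per_epoch = 10              # assumed to be what the caller passes in
    gradient_accumulation_steps = 4
    train_epochs = 3
    warmup_epochs = 1

    # Mirrors the new computation: round up to a whole number of accumulation cycles.
    steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps) \
        * gradient_accumulation_steps                       # ceil(10 / 4) * 4 = 12
    num_training_steps = train_epochs * steps_per_epoch     # 3 * 12 = 36
    num_warmup_steps = warmup_epochs * steps_per_epoch      # 1 * 12 = 12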
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
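For context, the chunking above is the usual prior-preservation pattern: the batch holds instance and class samples concatenated along the batch dimension, and the class ("prior") half gets its own loss term scaled by prior_loss_weight. A minimal sketch of that pattern follows; the helper name and the mean reduction are illustrative assumptions, not taken from this file.

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(model_pred: torch.Tensor, target: torch.Tensor,
                                prior_loss_weight: float) -> torch.Tensor:
        # Split predictions and targets back into instance and class ("prior") halves.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)
        # Instance loss plus weighted prior loss.
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
        return loss + prior_loss_weight * prior_loss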
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
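The train_loop bookkeeping now derives the totals directly from num_epochs rather than recomputing num_epochs from a step count that was itself derived from num_epochs. A small worked example with made-up numbers:

    import math

    # Hypothetical values, for illustration only.
    train_batches, val_batches = 25, 5
    gradient_accumulation_steps = 4
    num_epochs = 10

    num_training_steps_per_epoch = math.ceil(train_batches / gradient_accumulation_steps)  # ceil(25 / 4) = 7
    num_val_steps_per_epoch = val_batches                                                   # 5
    num_training_steps = num_training_steps_per_epoch * num_epochs                          # 70
    num_val_steps = num_val_steps_per_epoch * num_epochs                                    # 50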
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
@@ -309,7 +311,7 @@ def train_loop(
 
             local_progress_bar.set_postfix(**logs)
 
-            if global_step >= num_train_steps:
+            if global_step >= num_training_steps:
                 break
 
     accelerator.wait_for_everyone()
