Diffstat (limited to 'training')
-rw-r--r--  training/common.py  42
1 file changed, 22 insertions, 20 deletions
diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
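Note: the new code rounds the per-epoch step count up to a whole number of gradient-accumulation groups before deriving the warmup and total step counts. A rough sketch of that accounting with illustrative numbers (the concrete values and the assumption that the caller passes the raw dataloader length are not from the repository):

    import math

    gradient_accumulation_steps = 4
    num_training_steps_per_epoch = 10   # e.g. len(train_dataloader); illustrative value
    train_epochs, warmup_epochs = 40, 5

    # Round up to a whole number of accumulation groups: ceil(10 / 4) * 4 == 12.
    num_training_steps_per_epoch = math.ceil(
        num_training_steps_per_epoch / gradient_accumulation_steps
    ) * gradient_accumulation_steps

    num_training_steps = train_epochs * num_training_steps_per_epoch   # 40 * 12 == 480
    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch    # 5 * 12 == 60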
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
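For "cosine_with_restarts", the default cycle count is the square root of the number of post-warmup epochs, rounded up. Continuing the illustrative numbers from the sketch above (not taken from the repository):

    import math

    num_training_steps, num_warmup_steps, num_training_steps_per_epoch = 480, 60, 12
    cycles = math.ceil(math.sqrt((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))
    # (480 - 60) / 12 == 35 post-warmup epochs, sqrt(35) ~= 5.92, ceil -> 6 restart cycles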
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
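The with_prior branch assumes the batch concatenates instance examples and class ("prior") examples along the batch dimension, so prediction and target split cleanly in half. A minimal sketch of the weighted prior-preservation loss; the torch.chunk calls mirror the hunk above, while the F.mse_loss reduction and the helper name are assumptions, since the actual loss lines fall outside this diff:

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(model_pred, target, prior_loss_weight):
        # Batch layout assumed: [instance examples ; class ("prior") examples] along dim 0.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # The class-image loss is down-weighted by prior_loss_weight (a loss_step parameter).
        return loss + prior_loss_weight * prior_loss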
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
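With the renamed variables, the bookkeeping in train_loop works out as follows under illustrative assumptions (a training dataloader of 10 batches, a validation dataloader of 4 batches, gradient_accumulation_steps=4 and num_epochs=40; none of these values come from the repository):

    import math

    num_training_steps_per_epoch = math.ceil(10 / 4)   # 3 optimizer updates per epoch
    num_val_steps_per_epoch = 4

    num_training_steps = num_training_steps_per_epoch * 40   # 120
    num_val_steps = num_val_steps_per_epoch * 40              # 160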
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
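The two tqdm bars cover one epoch and the whole run respectively, and both are disabled on non-main processes so only one rank prints. Continuing the illustrative numbers above:

    local_total = 3 + 4        # num_training_steps_per_epoch + num_val_steps_per_epoch == 7 per epoch
    global_total = 120 + 160   # num_training_steps + num_val_steps == 280 for the whole run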
@@ -309,7 +311,7 @@ def train_loop(
 
             local_progress_bar.set_postfix(**logs)
 
-            if global_step >= num_train_steps:
+            if global_step >= num_training_steps:
                 break
 
         accelerator.wait_for_everyone()