 train_ti.py        | 24
 training/common.py | 42
 2 files changed, 33 insertions(+), 33 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8c86586..3f4e739 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,8 +750,6 @@ def main():
         args.sample_steps
     )
 
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -765,9 +763,9 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
-            num_train_epochs=args.num_train_epochs,
-            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            num_training_steps_per_epoch=len(train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
@@ -826,13 +824,13 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loop = partial(
+    loss_step_ = partial(
         loss_step,
         vae,
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
+        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -869,12 +867,12 @@ def main():
 
     if args.find_lr:
         lr_finder = LRFinder(
-            accelerator,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            val_dataloader,
-            loop,
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             on_train=on_train,
             on_eval=on_eval,
             on_after_optimize=on_after_optimize,
@@ -892,7 +890,7 @@ def main():
         checkpointer=checkpointer,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        loss_step=loop,
+        loss_step=loss_step_,
         sample_frequency=args.sample_frequency,
         sample_steps=args.sample_steps,
        checkpoint_frequency=args.checkpoint_frequency,
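
Note: the step arithmetic that train_ti.py used to do up front now happens inside get_scheduler (see the training/common.py hunk below). A minimal sketch with made-up numbers, not taken from the diff, showing that both formulations yield the same scheduler step counts:

import math

# Illustrative values only (assumptions, not values from this commit).
dataloader_len = 10      # len(train_dataloader)
grad_accum = 4           # gradient_accumulation_steps
train_epochs = 3         # args.num_train_epochs
warmup_epochs = 1        # args.lr_warmup_epochs

# Old path: train_ti.py pre-computed the per-epoch update count, and
# get_scheduler multiplied totals by gradient_accumulation_steps afterwards.
num_update_steps_per_epoch = math.ceil(dataloader_len / grad_accum)            # 3
old_training_steps = train_epochs * num_update_steps_per_epoch * grad_accum    # 36
old_warmup_steps = warmup_epochs * num_update_steps_per_epoch * grad_accum     # 12

# New path: get_scheduler receives len(train_dataloader) directly and rounds it
# up to a multiple of gradient_accumulation_steps itself.
steps_per_epoch = math.ceil(dataloader_len / grad_accum) * grad_accum          # 12
new_training_steps = train_epochs * steps_per_epoch                            # 36
new_warmup_steps = warmup_epochs * steps_per_epoch                             # 12

assert (old_training_steps, old_warmup_steps) == (new_training_steps, new_warmup_steps)
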
diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
@@ -309,7 +311,7 @@ def train_loop(
 
             local_progress_bar.set_postfix(**logs)
 
-            if global_step >= num_train_steps:
+            if global_step >= num_training_steps:
                 break
 
         accelerator.wait_for_everyone()
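
For reference, the with_prior branch shown in the loss_step hunk follows the usual prior-preservation pattern: the batch stacks instance samples followed by class (prior) samples, the prediction and target are split in two, and the two losses are combined with prior_loss_weight. A minimal sketch of that pattern; the helper name and the final combination are assumptions for illustration, not lines from this commit:

import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred: torch.Tensor, target: torch.Tensor, prior_loss_weight: float) -> torch.Tensor:
    # Assumes the batch holds instance samples first, then class (prior) samples.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Instance loss plus weighted prior loss.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    return loss + prior_loss_weight * prior_loss
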