-rw-r--r--  train_ti.py        | 24
-rw-r--r--  training/common.py | 42
2 files changed, 33 insertions(+), 33 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8c86586..3f4e739 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,8 +750,6 @@ def main():
         args.sample_steps
     )
 
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -765,9 +763,9 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
-            num_train_epochs=args.num_train_epochs,
-            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            num_training_steps_per_epoch=len(train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
@@ -826,13 +824,13 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loop = partial(
+    loss_step_ = partial(
         loss_step,
         vae,
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
+        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -869,12 +867,12 @@ def main():
 
     if args.find_lr:
         lr_finder = LRFinder(
-            accelerator,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            val_dataloader,
-            loop,
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             on_train=on_train,
             on_eval=on_eval,
             on_after_optimize=on_after_optimize,
@@ -892,7 +890,7 @@ def main():
         checkpointer=checkpointer,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        loss_step=loop,
+        loss_step=loss_step_,
         sample_frequency=args.sample_frequency,
         sample_steps=args.sample_steps,
         checkpoint_frequency=args.checkpoint_frequency,
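
The `loss_step_` object introduced above is a `functools.partial` that pre-binds the models and hyperparameters, so the training loop and the LR finder only supply the per-step arguments; note that the old integer `args.num_class_images` is now bound as the boolean `args.num_class_images != 0`. A minimal, self-contained sketch of the same pattern (the function body and values below are invented for illustration, not the project's actual code):

from functools import partial


def loss_step(model_name: str, with_prior: bool, prior_loss_weight: float,
              seed: int, step: int, batch: dict) -> float:
    # Hypothetical stand-in for training/common.py's loss_step: constant
    # arguments come first, per-step arguments (step, batch) come last.
    loss = float(batch["loss"])
    if with_prior:
        loss = loss + prior_loss_weight * float(batch["prior_loss"])
    return loss


num_class_images = 4  # assumed CLI value, for illustration only

# Bind everything that stays constant across steps; the boolean passed here
# plays the role of the new `with_prior` parameter.
loss_step_ = partial(loss_step, "text_encoder", num_class_images != 0, 0.5, 42)

# The loop (or LRFinder) then only passes what changes per step.
print(loss_step_(0, {"loss": 1.0, "prior_loss": 0.25}))
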
diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
@@ -309,7 +311,7 @@ def train_loop(
 
                 local_progress_bar.set_postfix(**logs)
 
-                if global_step >= num_train_steps:
+                if global_step >= num_training_steps:
                     break
 
         accelerator.wait_for_everyone()
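
For reference, a small worked example of the step arithmetic that get_scheduler now performs: the per-epoch count passed in as len(train_dataloader) is rounded up to a whole number of gradient-accumulation cycles, and the total and warmup step counts are derived from that rounded value. The numbers below are invented for illustration and are not taken from the repository:

import math

len_train_dataloader = 25          # batches per epoch, i.e. len(train_dataloader)
gradient_accumulation_steps = 4
train_epochs = 100
warmup_epochs = 10

# Round the raw dataloader length up to a multiple of the accumulation size,
# mirroring the new computation in get_scheduler.
num_training_steps_per_epoch = math.ceil(
    len_train_dataloader / gradient_accumulation_steps
) * gradient_accumulation_steps    # ceil(25 / 4) * 4 = 28

num_training_steps = train_epochs * num_training_steps_per_epoch   # 2800
num_warmup_steps = warmup_epochs * num_training_steps_per_epoch    # 280

print(num_training_steps_per_epoch, num_training_steps, num_warmup_steps)
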