From 24f7a14defdbc050907ad282ebdc7ea6f6591363 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 13 Apr 2023 09:13:35 +0200
Subject: Added cycle LR decay

---
 train_ti.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 94ddbb6..d931db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -901,6 +901,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
+    sample_output_dir = output_dir / project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -933,8 +935,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
-        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -943,6 +944,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -953,7 +955,7 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate is not None:
+        if learning_rate is not None:
             learning_rate *= args.cycle_decay
 
     accelerator.end_training()
--
cgit v1.2.3-54-g00ecf
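
Editor's note on the change: with this commit, each completed training cycle multiplies the base learning rate by args.cycle_decay, so cycle n (0-indexed) runs at learning_rate * cycle_decay**n. The commit also moves the guard from args.learning_rate to the local learning_rate, so the decay test checks the value actually being decayed. The snippet below is a minimal standalone sketch of that decay schedule, not code from the repository; the values are illustrative and the actual training call is elided.

    # Sketch of the cycle LR decay this patch introduces.
    # learning_rate, cycle_decay, and auto_cycles mirror the script's
    # arguments; the values here are made up for illustration.
    learning_rate = 5e-3
    cycle_decay = 0.9
    auto_cycles = 4

    for training_iter in range(auto_cycles):
        # Each cycle trains with the current base learning rate ...
        print(f"cycle {training_iter + 1}: lr = {learning_rate:.3e}")
        # ... and the rate is decayed after the cycle finishes, guarded
        # against an unset learning rate, matching the fixed condition
        # `if learning_rate is not None` in the patch.
        if learning_rate is not None:
            learning_rate *= cycle_decay

With cycle_decay = 0.9 this prints 5.000e-03, 4.500e-03, 4.050e-03, 3.645e-03: a geometric decay of the base rate across cycles, applied on top of whatever per-step schedule get_scheduler builds within each cycle.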