From a7dc66ae0974886a6c6a4c50def1b733bc04525a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 11 Apr 2023 17:02:22 +0200
Subject: Update

---
 train_ti.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 009495b..d7878cd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -870,20 +870,26 @@ def main():
                 args.num_train_steps / len(datamodule.train_dataset)
             ) * args.gradient_accumulation_steps
             sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))

+        num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)

         training_iter = 0
+        project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(project)
+
         while True:
-            training_iter += 1
-            if training_iter > args.auto_cycles:
+            if training_iter >= args.auto_cycles:
                 response = input("Run another cycle? [y/n] ")
                 if response.lower().strip() == "n":
                     break

             print("")
-            print(f"------------ TI cycle {training_iter} ------------")
+            print(f"------------ TI cycle {training_iter + 1} ------------")
             print("")

             optimizer = create_optimizer(
@@ -908,17 +914,16 @@ def main():
                 mid_point=args.lr_mid_point,
             )

-            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
-            sample_output_dir = output_dir / project / "samples"
-            checkpoint_output_dir = output_dir / project / "checkpoints"
+            sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
+            checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"

             trainer(
-                project=project,
                 train_dataloader=datamodule.train_dataloader,
                 val_dataloader=datamodule.val_dataloader,
                 optimizer=optimizer,
                 lr_scheduler=lr_scheduler,
                 num_train_epochs=num_train_epochs,
+                global_step_offset=training_iter * num_train_steps,
                 # --
                 group_labels=["emb"],
                 checkpoint_output_dir=checkpoint_output_dir,
@@ -928,6 +933,10 @@ def main():
                 placeholder_token_ids=placeholder_token_ids,
             )

+            training_iter += 1
+
+        accelerator.end_training()
+
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
--
cgit v1.2.3-54-g00ecf