From 810e9b3efeb99e76170486bdbb0f33a67e265dee Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 9 Apr 2023 09:13:24 +0200
Subject: Made Lora script interactive

---
 train_lora.py | 111 +++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 71 insertions(+), 40 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index d8a4880..f1e7ec7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -158,6 +158,12 @@ def parse_args():
         default=0,
         help="Tag dropout probability.",
     )
+    parser.add_argument(
+        "--pti_tag_dropout",
+        type=float,
+        default=0,
+        help="Tag dropout probability.",
+    )
     parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
@@ -891,7 +897,6 @@ def main():
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -919,12 +924,9 @@ def main():
     # --------------------------------------------------------------------------------
 
     if len(args.placeholder_tokens) != 0:
-        pti_output_dir = output_dir / "pti"
-        pti_checkpoint_output_dir = pti_output_dir / "model"
-        pti_sample_output_dir = pti_output_dir / "samples"
-
         pti_datamodule = create_datamodule(
             batch_size=args.pti_batch_size,
+            dropout=args.pti_tag_dropout,
             filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
         )
         pti_datamodule.setup()
@@ -955,22 +957,38 @@ def main():
             train_epochs=num_pti_epochs,
         )
 
-        trainer(
-            strategy=lora_strategy,
-            pti_mode=True,
-            project="pti",
-            train_dataloader=pti_datamodule.train_dataloader,
-            val_dataloader=pti_datamodule.val_dataloader,
-            optimizer=pti_optimizer,
-            lr_scheduler=pti_lr_scheduler,
-            num_train_epochs=num_pti_epochs,
-            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=pti_sample_output_dir,
-            checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=pti_sample_frequency,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print("")
+            print(f"============ PTI cycle {training_iter} ============")
+            print("")
+
+            pti_output_dir = output_dir / f"pti_{training_iter}"
+            pti_checkpoint_output_dir = pti_output_dir / "model"
+            pti_sample_output_dir = pti_output_dir / "samples"
+
+            trainer(
+                strategy=lora_strategy,
+                pti_mode=True,
+                project="pti",
+                train_dataloader=pti_datamodule.train_dataloader,
+                val_dataloader=pti_datamodule.val_dataloader,
+                optimizer=pti_optimizer,
+                lr_scheduler=pti_lr_scheduler,
+                num_train_epochs=num_pti_epochs,
+                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+                # --
+                group_labels=["emb"],
+                sample_output_dir=pti_sample_output_dir,
+                checkpoint_output_dir=pti_checkpoint_output_dir,
+                sample_frequency=pti_sample_frequency,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
     if not args.train_emb:
         embeddings.persist()
@@ -978,12 +996,9 @@ def main():
     # LORA
     # --------------------------------------------------------------------------------
 
-    lora_output_dir = output_dir / "lora"
-    lora_checkpoint_output_dir = lora_output_dir / "model"
-    lora_sample_output_dir = lora_output_dir / "samples"
-
     lora_datamodule = create_datamodule(
         batch_size=args.train_batch_size,
+        dropout=args.tag_dropout,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -1037,21 +1052,37 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    trainer(
-        strategy=lora_strategy,
-        project="lora",
-        train_dataloader=lora_datamodule.train_dataloader,
-        val_dataloader=lora_datamodule.val_dataloader,
-        optimizer=lora_optimizer,
-        lr_scheduler=lora_lr_scheduler,
-        num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        # --
-        group_labels=group_labels,
-        sample_output_dir=lora_sample_output_dir,
-        checkpoint_output_dir=lora_checkpoint_output_dir,
-        sample_frequency=lora_sample_frequency,
-    )
+    continue_training = True
+    training_iter = 1
+
+    while continue_training:
+        print("")
+        print(f"============ LoRA cycle {training_iter} ============")
+        print("")
+
+        lora_output_dir = output_dir / f"lora_{training_iter}"
+        lora_checkpoint_output_dir = lora_output_dir / "model"
+        lora_sample_output_dir = lora_output_dir / "samples"
+
+        trainer(
+            strategy=lora_strategy,
+            project=f"lora_{training_iter}",
+            train_dataloader=lora_datamodule.train_dataloader,
+            val_dataloader=lora_datamodule.val_dataloader,
+            optimizer=lora_optimizer,
+            lr_scheduler=lora_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            # --
+            group_labels=group_labels,
+            sample_output_dir=lora_sample_output_dir,
+            checkpoint_output_dir=lora_checkpoint_output_dir,
+            sample_frequency=lora_sample_frequency,
+        )
+
+        response = input("Run another cycle? [y/n] ")
+        continue_training = response.lower().strip() != "n"
+        training_iter += 1
 
 
 if __name__ == "__main__":
-- 
cgit v1.2.3-70-g09d2