path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-04-11 17:02:22 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-11 17:02:22 +0200
commit     a7dc66ae0974886a6c6a4c50def1b733bc04525a (patch)
tree       bbea49b82f8f87b0ce6141114875d6253c75d8ab /train_ti.py
parent     Randomize dataset across cycles (diff)
download   textual-inversion-diff-a7dc66ae0974886a6c6a4c50def1b733bc04525a.tar.gz
           textual-inversion-diff-a7dc66ae0974886a6c6a4c50def1b733bc04525a.tar.bz2
           textual-inversion-diff-a7dc66ae0974886a6c6a4c50def1b733bc04525a.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  23
1 file changed, 16 insertions(+), 7 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 009495b..d7878cd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -870,20 +870,26 @@ def main():
                 args.num_train_steps / len(datamodule.train_dataset)
             ) * args.gradient_accumulation_steps
             sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+            num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+            num_train_steps = num_training_steps_per_epoch * num_train_epochs
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
         training_iter = 0
 
+        project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(project)
+
         while True:
-            training_iter += 1
-            if training_iter > args.auto_cycles:
+            if training_iter >= args.auto_cycles:
                 response = input("Run another cycle? [y/n] ")
                 if response.lower().strip() == "n":
                     break
 
             print("")
-            print(f"------------ TI cycle {training_iter} ------------")
+            print(f"------------ TI cycle {training_iter + 1} ------------")
             print("")
 
             optimizer = create_optimizer(
@@ -908,17 +914,16 @@ def main():
                 mid_point=args.lr_mid_point,
             )
 
-            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
-            sample_output_dir = output_dir / project / "samples"
-            checkpoint_output_dir = output_dir / project / "checkpoints"
+            sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
+            checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
 
             trainer(
-                project=project,
                 train_dataloader=datamodule.train_dataloader,
                 val_dataloader=datamodule.val_dataloader,
                 optimizer=optimizer,
                 lr_scheduler=lr_scheduler,
                 num_train_epochs=num_train_epochs,
+                global_step_offset=training_iter * num_train_steps,
                 # --
                 group_labels=["emb"],
                 checkpoint_output_dir=checkpoint_output_dir,
@@ -928,6 +933,10 @@ def main():
                 placeholder_token_ids=placeholder_token_ids,
             )
 
+            training_iter += 1
+
+        accelerator.end_training()
+
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
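
Taken together, the patch moves cycle bookkeeping out of the loop body: the tracker is initialized once per run, each cycle writes its samples and checkpoints under a numbered subdirectory, and the trainer receives a global_step_offset so step counts keep climbing across cycles instead of restarting at zero. Below is a minimal sketch of that loop shape, not the repository's actual code; run_cycles and the trainer/create_optimizer callables are hypothetical stand-ins, and only global_step_offset, auto_cycles, and the directory layout are taken from the diff above.

import math
from pathlib import Path

def run_cycles(args, datamodule, trainer, create_optimizer, output_dir: Path, project: str):
    # Steps per cycle, derived as in the patch: dataset length divided by
    # gradient accumulation, times the number of epochs per cycle.
    num_training_steps_per_epoch = math.ceil(
        len(datamodule.train_dataset) / args.gradient_accumulation_steps
    )
    num_train_steps = num_training_steps_per_epoch * args.num_train_epochs

    training_iter = 0  # completed cycles; 0-based, printed as cycle {n + 1}
    while True:
        # After auto_cycles unattended cycles, ask before starting another.
        if training_iter >= args.auto_cycles:
            if input("Run another cycle? [y/n] ").lower().strip() == "n":
                break

        print(f"------------ TI cycle {training_iter + 1} ------------")

        # Per-cycle outputs: <output_dir>/<project>/<cycle>/{samples,checkpoints}
        cycle_dir = output_dir / project / f"{training_iter + 1}"

        trainer(
            optimizer=create_optimizer(),
            num_train_epochs=args.num_train_epochs,
            # Offsetting by completed cycles keeps the global step monotonic
            # across cycles, so a single tracker run charts them on one axis.
            global_step_offset=training_iter * num_train_steps,
            sample_output_dir=cycle_dir / "samples",
            checkpoint_output_dir=cycle_dir / "checkpoints",
        )

        training_iter += 1

Initializing the tracker once before the loop is what makes the offset useful: where the removed per-cycle project name (f"{placeholder_tokens[0]}_{training_iter}") produced one tracker run per cycle, all cycles now log into a single run under one project name.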