diff options
author | Volpeon <git@volpeon.ink> | 2023-04-11 17:02:22 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-11 17:02:22 +0200 |
commit | a7dc66ae0974886a6c6a4c50def1b733bc04525a (patch) | |
tree | bbea49b82f8f87b0ce6141114875d6253c75d8ab /train_ti.py | |
parent | Randomize dataset across cycles (diff) | |
download | textual-inversion-diff-a7dc66ae0974886a6c6a4c50def1b733bc04525a.tar.gz textual-inversion-diff-a7dc66ae0974886a6c6a4c50def1b733bc04525a.tar.bz2 textual-inversion-diff-a7dc66ae0974886a6c6a4c50def1b733bc04525a.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/train_ti.py b/train_ti.py index 009495b..d7878cd 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -870,20 +870,26 @@ def main(): | |||
870 | args.num_train_steps / len(datamodule.train_dataset) | 870 | args.num_train_steps / len(datamodule.train_dataset) |
871 | ) * args.gradient_accumulation_steps | 871 | ) * args.gradient_accumulation_steps |
872 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 872 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
873 | num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) | ||
874 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
873 | if args.sample_num is not None: | 875 | if args.sample_num is not None: |
874 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 876 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
875 | 877 | ||
876 | training_iter = 0 | 878 | training_iter = 0 |
877 | 879 | ||
880 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | ||
881 | |||
882 | if accelerator.is_main_process: | ||
883 | accelerator.init_trackers(project) | ||
884 | |||
878 | while True: | 885 | while True: |
879 | training_iter += 1 | 886 | if training_iter >= args.auto_cycles: |
880 | if training_iter > args.auto_cycles: | ||
881 | response = input("Run another cycle? [y/n] ") | 887 | response = input("Run another cycle? [y/n] ") |
882 | if response.lower().strip() == "n": | 888 | if response.lower().strip() == "n": |
883 | break | 889 | break |
884 | 890 | ||
885 | print("") | 891 | print("") |
886 | print(f"------------ TI cycle {training_iter} ------------") | 892 | print(f"------------ TI cycle {training_iter + 1} ------------") |
887 | print("") | 893 | print("") |
888 | 894 | ||
889 | optimizer = create_optimizer( | 895 | optimizer = create_optimizer( |
@@ -908,17 +914,16 @@ def main(): | |||
908 | mid_point=args.lr_mid_point, | 914 | mid_point=args.lr_mid_point, |
909 | ) | 915 | ) |
910 | 916 | ||
911 | project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}" | 917 | sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples" |
912 | sample_output_dir = output_dir / project / "samples" | 918 | checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints" |
913 | checkpoint_output_dir = output_dir / project / "checkpoints" | ||
914 | 919 | ||
915 | trainer( | 920 | trainer( |
916 | project=project, | ||
917 | train_dataloader=datamodule.train_dataloader, | 921 | train_dataloader=datamodule.train_dataloader, |
918 | val_dataloader=datamodule.val_dataloader, | 922 | val_dataloader=datamodule.val_dataloader, |
919 | optimizer=optimizer, | 923 | optimizer=optimizer, |
920 | lr_scheduler=lr_scheduler, | 924 | lr_scheduler=lr_scheduler, |
921 | num_train_epochs=num_train_epochs, | 925 | num_train_epochs=num_train_epochs, |
926 | global_step_offset=training_iter * num_train_steps, | ||
922 | # -- | 927 | # -- |
923 | group_labels=["emb"], | 928 | group_labels=["emb"], |
924 | checkpoint_output_dir=checkpoint_output_dir, | 929 | checkpoint_output_dir=checkpoint_output_dir, |
@@ -928,6 +933,10 @@ def main(): | |||
928 | placeholder_token_ids=placeholder_token_ids, | 933 | placeholder_token_ids=placeholder_token_ids, |
929 | ) | 934 | ) |
930 | 935 | ||
936 | training_iter += 1 | ||
937 | |||
938 | accelerator.end_training() | ||
939 | |||
931 | if not args.sequential: | 940 | if not args.sequential: |
932 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 941 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
933 | else: | 942 | else: |