| author | Volpeon <git@volpeon.ink> | 2023-04-11 17:02:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-11 17:02:22 +0200 |
| commit | a7dc66ae0974886a6c6a4c50def1b733bc04525a (patch) | |
| tree | bbea49b82f8f87b0ce6141114875d6253c75d8ab /train_ti.py | |
| parent | Randomize dataset across cycles (diff) | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 23 |
1 file changed, 16 insertions(+), 7 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 009495b..d7878cd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -870,20 +870,26 @@ def main():
                 args.num_train_steps / len(datamodule.train_dataset)
             ) * args.gradient_accumulation_steps
             sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
         training_iter = 0
 
+        project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(project)
+
         while True:
-            training_iter += 1
-            if training_iter > args.auto_cycles:
+            if training_iter >= args.auto_cycles:
                 response = input("Run another cycle? [y/n] ")
                 if response.lower().strip() == "n":
                     break
 
             print("")
-            print(f"------------ TI cycle {training_iter} ------------")
+            print(f"------------ TI cycle {training_iter + 1} ------------")
             print("")
 
             optimizer = create_optimizer(
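This hunk turns `training_iter` into a zero-based counter that is only advanced at the end of a cycle (see the third hunk below), so the first `args.auto_cycles` cycles run unattended and the confirmation prompt appears only once that budget is spent. A minimal sketch of the control flow, with `auto_cycles` and the print standing in for the script's argument parsing and training code:

```python
def run_cycles(auto_cycles: int) -> None:
    """Run `auto_cycles` cycles unattended, then ask before each further one."""
    training_iter = 0
    while True:
        if training_iter >= auto_cycles:
            # The prompt only appears once the automatic budget is spent.
            response = input("Run another cycle? [y/n] ")
            if response.lower().strip() == "n":
                break
        print(f"------------ TI cycle {training_iter + 1} ------------")
        # ... one training cycle would run here ...
        training_iter += 1


run_cycles(auto_cycles=2)
```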
@@ -908,17 +914,16 @@ def main():
                 mid_point=args.lr_mid_point,
             )
 
-            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
-            sample_output_dir = output_dir / project / "samples"
-            checkpoint_output_dir = output_dir / project / "checkpoints"
+            sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
+            checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
 
             trainer(
-                project=project,
                 train_dataloader=datamodule.train_dataloader,
                 val_dataloader=datamodule.val_dataloader,
                 optimizer=optimizer,
                 lr_scheduler=lr_scheduler,
                 num_train_epochs=num_train_epochs,
+                global_step_offset=training_iter * num_train_steps,
                 # --
                 group_labels=["emb"],
                 checkpoint_output_dir=checkpoint_output_dir,
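With `num_train_steps` now computed up front, `global_step_offset = training_iter * num_train_steps` lets the step counter continue across cycles instead of restarting at zero, while the per-cycle subdirectory (`f"{training_iter + 1}"`) keeps each cycle's samples and checkpoints apart. A worked sketch of the offset arithmetic; the dataset size, accumulation steps, and epoch count are made-up numbers for illustration only:

```python
import math

# Illustrative numbers only; the script derives these from the dataset and CLI args.
dataset_len = 100
gradient_accumulation_steps = 4
num_train_epochs = 3

num_training_steps_per_epoch = math.ceil(dataset_len / gradient_accumulation_steps)  # 25
num_train_steps = num_training_steps_per_epoch * num_train_epochs  # 75

for training_iter in range(3):
    offset = training_iter * num_train_steps
    # Cycle 1 logs steps 0..74, cycle 2 logs 75..149, cycle 3 logs 150..224.
    print(f"cycle {training_iter + 1}: global steps {offset}..{offset + num_train_steps - 1}")
```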
@@ -928,6 +933,10 @@ def main():
                 placeholder_token_ids=placeholder_token_ids,
             )
 
+            training_iter += 1
+
+        accelerator.end_training()
+
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
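Moving `init_trackers` ahead of the loop (first hunk) and calling `end_training()` after it means a single tracker run spans all cycles rather than one run per cycle. A minimal sketch of that lifecycle using Hugging Face `accelerate`; the `log_with`/`project_dir` values, the fixed `num_train_steps`, and the logged metric are placeholders, not the script's actual configuration:

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="output")

# Initialize the trackers once, before any cycle runs.
if accelerator.is_main_process:
    accelerator.init_trackers("ti")

num_train_steps = 75  # placeholder; the script computes this from the dataset
for training_iter in range(2):
    global_step_offset = training_iter * num_train_steps
    # ... one training cycle; steps are logged relative to the offset ...
    accelerator.log({"cycle": training_iter + 1}, step=global_step_offset)

# Flush and close all trackers once, after the last cycle.
accelerator.end_training()
```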
