diff options
author | Volpeon <git@volpeon.ink> | 2023-04-13 09:13:35 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-13 09:13:35 +0200 |
commit | 24f7a14defdbc050907ad282ebdc7ea6f6591363 (patch) | |
tree | 996eaa43214f7c1ce7634605a8c08c9204898944 /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-24f7a14defdbc050907ad282ebdc7ea6f6591363.tar.gz textual-inversion-diff-24f7a14defdbc050907ad282ebdc7ea6f6591363.tar.bz2 textual-inversion-diff-24f7a14defdbc050907ad282ebdc7ea6f6591363.zip |
Added cycle LR decay
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 94ddbb6..d931db6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -901,6 +901,8 @@ def main(): | |||
901 | if accelerator.is_main_process: | 901 | if accelerator.is_main_process: |
902 | accelerator.init_trackers(project) | 902 | accelerator.init_trackers(project) |
903 | 903 | ||
904 | sample_output_dir = output_dir / project / "samples" | ||
905 | |||
904 | while True: | 906 | while True: |
905 | if training_iter >= args.auto_cycles: | 907 | if training_iter >= args.auto_cycles: |
906 | response = input("Run another cycle? [y/n] ") | 908 | response = input("Run another cycle? [y/n] ") |
@@ -933,8 +935,7 @@ def main(): | |||
933 | mid_point=args.lr_mid_point, | 935 | mid_point=args.lr_mid_point, |
934 | ) | 936 | ) |
935 | 937 | ||
936 | sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples" | 938 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}" |
937 | checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints" | ||
938 | 939 | ||
939 | trainer( | 940 | trainer( |
940 | train_dataloader=datamodule.train_dataloader, | 941 | train_dataloader=datamodule.train_dataloader, |
@@ -943,6 +944,7 @@ def main(): | |||
943 | lr_scheduler=lr_scheduler, | 944 | lr_scheduler=lr_scheduler, |
944 | num_train_epochs=num_train_epochs, | 945 | num_train_epochs=num_train_epochs, |
945 | global_step_offset=training_iter * num_train_steps, | 946 | global_step_offset=training_iter * num_train_steps, |
947 | initial_samples=training_iter == 0, | ||
946 | # -- | 948 | # -- |
947 | group_labels=["emb"], | 949 | group_labels=["emb"], |
948 | checkpoint_output_dir=checkpoint_output_dir, | 950 | checkpoint_output_dir=checkpoint_output_dir, |
@@ -953,7 +955,7 @@ def main(): | |||
953 | ) | 955 | ) |
954 | 956 | ||
955 | training_iter += 1 | 957 | training_iter += 1 |
956 | if args.learning_rate is not None: | 958 | if learning_rate is not None: |
957 | learning_rate *= args.cycle_decay | 959 | learning_rate *= args.cycle_decay |
958 | 960 | ||
959 | accelerator.end_training() | 961 | accelerator.end_training() |