author     Volpeon <git@volpeon.ink>  2023-04-13 09:13:35 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-13 09:13:35 +0200
commit     24f7a14defdbc050907ad282ebdc7ea6f6591363 (patch)
tree       996eaa43214f7c1ce7634605a8c08c9204898944 /train_ti.py
parent     Update (diff)
Added cycle LR decay
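
In effect, each cycle now rebuilds the one-cycle LR scheduler from a peak learning rate that is multiplied by args.cycle_decay at the end of every cycle. A minimal runnable sketch of that loop, where run_cycle() and the constant values are illustrative stand-ins for the real trainer() call and CLI arguments, not this repo's API:

# Hypothetical sketch of the cycle LR decay added here; run_cycle()
# stands in for the real trainer() call inside main()'s while loop.
def run_cycle(cycle: int, lr: float) -> None:
    print(f"cycle {cycle}: training at peak lr = {lr:.2e}")

learning_rate = 1e-3  # illustrative; the script takes this from args.learning_rate
cycle_decay = 0.9     # illustrative; as args.cycle_decay in the diff
auto_cycles = 3       # illustrative; as args.auto_cycles in the diff

training_iter = 0
while training_iter < auto_cycles:
    run_cycle(training_iter, learning_rate)
    training_iter += 1
    if learning_rate is not None:
        learning_rate *= cycle_decay  # shrink the peak LR for the next cycle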
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  8
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 94ddbb6..d931db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -901,6 +901,8 @@ def main():
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
+    sample_output_dir = output_dir / project / "samples"
+
     while True:
         if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
@@ -933,8 +935,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        sample_output_dir = output_dir / project / f"{training_iter + 1}" / "samples"
-        checkpoint_output_dir = output_dir / project / f"{training_iter + 1}" / "checkpoints"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -943,6 +944,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
+            initial_samples=training_iter == 0,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
@@ -953,7 +955,7 @@ def main():
         )
 
         training_iter += 1
-        if args.learning_rate is not None:
+        if learning_rate is not None:
             learning_rate *= args.cycle_decay
 
     accelerator.end_training()
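
Because the decay compounds, cycle n trains at a peak LR of learning_rate * cycle_decay**n once a concrete learning_rate is set (the last hunk fixes the check to test the decayed local value instead of args.learning_rate). A quick worked example with assumed values, not the repo's defaults:

lr0, cycle_decay = 1e-3, 0.9  # assumed values for illustration
for n in range(4):
    print(f"cycle {n}: peak lr = {lr0 * cycle_decay ** n:.2e}")
# cycle 0: peak lr = 1.00e-03
# cycle 1: peak lr = 9.00e-04
# cycle 2: peak lr = 8.10e-04
# cycle 3: peak lr = 7.29e-04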