path: root/train_ti.py
author     Volpeon <git@volpeon.ink>    2023-04-16 19:03:25 +0200
committer  Volpeon <git@volpeon.ink>    2023-04-16 19:03:25 +0200
commit     71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch)
tree       29c704ca549a4c4323403b6cbb0e62f54040ae22 /train_ti.py
parent     Added option to use constant LR on cycles > 1 (diff)
download   textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.gz
           textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.bz2
           textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.zip
Improved automation caps
Diffstat (limited to 'train_ti.py')
-rw-r--r--   train_ti.py   53
1 file changed, 33 insertions(+), 20 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index c452269..880320f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -68,9 +68,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -79,11 +79,6 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
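
Note: after this change --auto_cycles is no longer a cycle count but a string of per-cycle action codes ("o", "w", "c", "d", "s"), consumed one character at a time before the script falls back to interactive prompts. A minimal sketch of how such a string could be mapped to per-cycle LR settings, mirroring the action letters used in the loop below; the plan_cycles helper and CycleConfig container are illustrative and not part of this patch:

    # Illustrative only: translate an --auto_cycles string such as "owd"
    # into the LR settings each training cycle would use.
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class CycleConfig:
        lr_scheduler: Optional[str]   # None signals "stop"
        lr_warmup_epochs: int
        lr_cycles: Optional[int]

    def plan_cycles(codes: str, warmup_epochs: int, cycles: Optional[int],
                    num_train_epochs: int) -> List[CycleConfig]:
        plan = []
        for action in codes:
            if action == "o":    # full one_cycle schedule
                plan.append(CycleConfig("one_cycle", warmup_epochs, cycles))
            elif action == "w":  # constant LR, warm up over the whole cycle
                plan.append(CycleConfig("constant", num_train_epochs, cycles))
            elif action == "c":  # constant LR, no warmup
                plan.append(CycleConfig("constant", 0, cycles))
            elif action == "d":  # single cosine decay
                plan.append(CycleConfig("cosine", 0, 1))
            elif action == "s":  # stop training
                plan.append(CycleConfig(None, 0, None))
                break
        return plan

    # e.g. plan_cycles("owd", warmup_epochs=10, cycles=None, num_train_epochs=25)
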
@@ -921,27 +916,45 @@ def main():
 
     sample_output_dir = output_dir / project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
         )
 
         lr_scheduler = get_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             optimizer=optimizer,
             num_training_steps_per_epoch=len(datamodule.train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -950,10 +963,10 @@ def main():
             annealing_func=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
+            cycles=lr_cycles,
             end_lr=1e3,
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            warmup_epochs=lr_warmup_epochs,
             mid_point=args.lr_mid_point,
         )
 
@@ -966,7 +979,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
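
Note: the rewritten loop drains the pre-planned auto_cycles list first and only prompts on stdin once it is empty, keeping the scheduler type, warmup epochs, and cycle count in per-cycle locals rather than mutating args. A self-contained sketch of that drain-then-prompt pattern, using a hypothetical next_action helper and a placeholder for the per-cycle training step:

    # Illustrative only: pop pre-planned actions, otherwise ask interactively.
    def next_action(auto_cycles):
        if auto_cycles:
            return auto_cycles.pop(0)
        return input(
            "Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop\n--> "
        ).lower().strip()

    auto_cycles = list("ow")       # e.g. parsed from --auto_cycles "ow"
    training_iter = 0
    while True:
        action = next_action(auto_cycles)
        if action == "s":
            break
        # ... pick scheduler settings for this cycle, build the optimizer,
        # ... then train, passing training_iter as the cycle index ...
        training_iter += 1
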