diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 53 |
1 files changed, 33 insertions, 20 deletions
diff --git a/train_ti.py b/train_ti.py index c452269..880320f 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -68,9 +68,9 @@ def parse_args(): | |||
68 | ) | 68 | ) |
69 | parser.add_argument( | 69 | parser.add_argument( |
70 | "--auto_cycles", | 70 | "--auto_cycles", |
71 | type=int, | 71 | type=str, |
72 | default=1, | 72 | default="o", |
73 | help="How many cycles to run automatically." | 73 | help="Cycles to run automatically." |
74 | ) | 74 | ) |
75 | parser.add_argument( | 75 | parser.add_argument( |
76 | "--cycle_decay", | 76 | "--cycle_decay", |
@@ -79,11 +79,6 @@ def parse_args(): | |||
79 | help="Learning rate decay per cycle." | 79 | help="Learning rate decay per cycle." |
80 | ) | 80 | ) |
81 | parser.add_argument( | 81 | parser.add_argument( |
82 | "--cycle_constant", | ||
83 | action="store_true", | ||
84 | help="Use constant LR on cycles > 1." | ||
85 | ) | ||
86 | parser.add_argument( | ||
87 | "--placeholder_tokens", | 82 | "--placeholder_tokens", |
88 | type=str, | 83 | type=str, |
89 | nargs='*', | 84 | nargs='*', |
@@ -921,27 +916,45 @@ def main(): | |||
921 | 916 | ||
922 | sample_output_dir = output_dir / project / "samples" | 917 | sample_output_dir = output_dir / project / "samples" |
923 | 918 | ||
919 | auto_cycles = list(args.auto_cycles) | ||
920 | lr_scheduler = args.lr_scheduler | ||
921 | lr_warmup_epochs = args.lr_warmup_epochs | ||
922 | lr_cycles = args.lr_cycles | ||
923 | |||
924 | while True: | 924 | while True: |
925 | if training_iter >= args.auto_cycles: | 925 | if len(auto_cycles) != 0: |
926 | response = input("Run another cycle? [y/n] ") | 926 | response = auto_cycles.pop(0) |
927 | if response.lower().strip() == "n": | 927 | else: |
928 | break | 928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
929 | |||
930 | if response.lower().strip() == "o": | ||
931 | lr_scheduler = "one_cycle" | ||
932 | lr_warmup_epochs = args.lr_warmup_epochs | ||
933 | lr_cycles = args.lr_cycles | ||
934 | if response.lower().strip() == "w": | ||
935 | lr_scheduler = "constant" | ||
936 | lr_warmup_epochs = num_train_epochs | ||
937 | if response.lower().strip() == "c": | ||
938 | lr_scheduler = "constant" | ||
939 | lr_warmup_epochs = 0 | ||
940 | if response.lower().strip() == "d": | ||
941 | lr_scheduler = "cosine" | ||
942 | lr_warmup_epochs = 0 | ||
943 | lr_cycles = 1 | ||
944 | elif response.lower().strip() == "s": | ||
945 | break | ||
929 | 946 | ||
930 | print("") | 947 | print("") |
931 | print(f"------------ TI cycle {training_iter + 1} ------------") | 948 | print(f"------------ TI cycle {training_iter + 1} ------------") |
932 | print("") | 949 | print("") |
933 | 950 | ||
934 | if args.cycle_constant and training_iter == 1: | ||
935 | args.lr_scheduler = "constant" | ||
936 | args.lr_warmup_epochs = 0 | ||
937 | |||
938 | optimizer = create_optimizer( | 951 | optimizer = create_optimizer( |
939 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 952 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
940 | lr=learning_rate, | 953 | lr=learning_rate, |
941 | ) | 954 | ) |
942 | 955 | ||
943 | lr_scheduler = get_scheduler( | 956 | lr_scheduler = get_scheduler( |
944 | args.lr_scheduler, | 957 | lr_scheduler, |
945 | optimizer=optimizer, | 958 | optimizer=optimizer, |
946 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 959 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
947 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 960 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
@@ -950,10 +963,10 @@ def main(): | |||
950 | annealing_func=args.lr_annealing_func, | 963 | annealing_func=args.lr_annealing_func, |
951 | warmup_exp=args.lr_warmup_exp, | 964 | warmup_exp=args.lr_warmup_exp, |
952 | annealing_exp=args.lr_annealing_exp, | 965 | annealing_exp=args.lr_annealing_exp, |
953 | cycles=args.lr_cycles, | 966 | cycles=lr_cycles, |
954 | end_lr=1e3, | 967 | end_lr=1e3, |
955 | train_epochs=num_train_epochs, | 968 | train_epochs=num_train_epochs, |
956 | warmup_epochs=args.lr_warmup_epochs, | 969 | warmup_epochs=lr_warmup_epochs, |
957 | mid_point=args.lr_mid_point, | 970 | mid_point=args.lr_mid_point, |
958 | ) | 971 | ) |
959 | 972 | ||
@@ -966,7 +979,7 @@ def main(): | |||
966 | lr_scheduler=lr_scheduler, | 979 | lr_scheduler=lr_scheduler, |
967 | num_train_epochs=num_train_epochs, | 980 | num_train_epochs=num_train_epochs, |
968 | global_step_offset=training_iter * num_train_steps, | 981 | global_step_offset=training_iter * num_train_steps, |
969 | initial_samples=training_iter == 0, | 982 | cycle=training_iter, |
970 | # -- | 983 | # -- |
971 | group_labels=["emb"], | 984 | group_labels=["emb"], |
972 | checkpoint_output_dir=checkpoint_output_dir, | 985 | checkpoint_output_dir=checkpoint_output_dir, |