| author | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
| commit | 71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch) | |
| tree | 29c704ca549a4c4323403b6cbb0e62f54040ae22 /train_ti.py | |
| parent | Added option to use constant LR on cycles > 1 (diff) | |
Improved automation caps
Diffstat (limited to 'train_ti.py')
| file | permissions | lines changed |
|---|---|---|
| train_ti.py | -rw-r--r-- | 53 |

1 file changed, 33 insertions, 20 deletions
diff --git a/train_ti.py b/train_ti.py
index c452269..880320f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -68,9 +68,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -79,11 +79,6 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -921,27 +916,45 @@ def main():
 
     sample_output_dir = output_dir / project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
         )
 
         lr_scheduler = get_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             optimizer=optimizer,
             num_training_steps_per_epoch=len(datamodule.train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -950,10 +963,10 @@ def main():
             annealing_func=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
+            cycles=lr_cycles,
             end_lr=1e3,
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            warmup_epochs=lr_warmup_epochs,
             mid_point=args.lr_mid_point,
         )
 
@@ -966,7 +979,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
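
In short, the patch turns `--auto_cycles` from an integer cycle count (plus the separate `--cycle_constant` flag) into a string of per-cycle actions that is consumed one letter at a time; once the string is exhausted, the script falls back to an interactive prompt. The following is a minimal, self-contained sketch of that control flow, not the actual train_ti.py: the argparse harness, the default values, the `elif` dispatch (the diff uses separate `if`s and lets unknown letters keep the previous settings), and the final `print` standing in for the optimizer/scheduler/trainer calls are all illustrative assumptions; only the action letters and the scheduler settings they select are taken from the diff.

```python
import argparse

def main():
    # Hypothetical harness: only the options needed to illustrate the new control flow.
    parser = argparse.ArgumentParser()
    parser.add_argument("--auto_cycles", type=str, default="o",
                        help="Cycles to run automatically.")
    parser.add_argument("--lr_warmup_epochs", type=int, default=10)   # assumed default
    parser.add_argument("--lr_cycles", type=int, default=None)        # assumed default
    parser.add_argument("--num_train_epochs", type=int, default=30)   # assumed default
    args = parser.parse_args()

    auto_cycles = list(args.auto_cycles)  # e.g. "owd" -> ["o", "w", "d"]
    training_iter = 0

    while True:
        # Consume the next scripted action; fall back to the prompt once the string is used up.
        if auto_cycles:
            response = auto_cycles.pop(0)
        else:
            response = input(
                "Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
            )
        response = response.lower().strip()

        if response == "o":    # one-cycle schedule with the configured warmup and cycle count
            lr_scheduler, lr_warmup_epochs, lr_cycles = "one_cycle", args.lr_warmup_epochs, args.lr_cycles
        elif response == "w":  # constant LR, warming up over the whole cycle
            lr_scheduler, lr_warmup_epochs, lr_cycles = "constant", args.num_train_epochs, args.lr_cycles
        elif response == "c":  # constant LR, no warmup
            lr_scheduler, lr_warmup_epochs, lr_cycles = "constant", 0, args.lr_cycles
        elif response == "d":  # single cosine decay
            lr_scheduler, lr_warmup_epochs, lr_cycles = "cosine", 0, 1
        elif response == "s":  # stop training
            break
        else:
            print(f"Unknown action {response!r}, ignoring.")
            continue

        # Stand-in for the optimizer / get_scheduler / trainer setup in train_ti.py.
        print(f"cycle {training_iter + 1}: scheduler={lr_scheduler}, "
              f"warmup_epochs={lr_warmup_epochs}, cycles={lr_cycles}")
        training_iter += 1

if __name__ == "__main__":
    main()
```

Under these assumptions, running the sketch with `--auto_cycles od` performs one `one_cycle` pass and one cosine-decay pass before dropping to the interactive prompt, while the default `"o"` roughly reproduces the old behaviour of a single automatic cycle followed by a prompt.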
