Diffstat (limited to 'train_ti.py')
-rw-r--r--	train_ti.py	13
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 082e9b7..94ddbb6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -72,6 +72,12 @@ def parse_args():
         help="How many cycles to run automatically."
     )
     parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -672,7 +678,6 @@ def main():
         convnext.to(accelerator.device, dtype=weight_dtype)
         convnext.requires_grad_(False)
         convnext.eval()
-        disc = ConvNeXtDiscriminator(convnext, input_size=384)

     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -815,7 +820,6 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
-        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -890,6 +894,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs / args.sample_num)

     training_iter = 0
+    learning_rate = args.learning_rate

     project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"

@@ -908,7 +913,7 @@ def main():

         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            lr=args.learning_rate,
+            lr=learning_rate,
         )

         lr_scheduler = get_scheduler(
@@ -948,6 +953,8 @@ def main():
         )

         training_iter += 1
+        if args.learning_rate is not None:
+            learning_rate *= args.cycle_decay

     accelerator.end_training()
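The change is twofold: the unused ConvNeXtDiscriminator wiring is dropped, and the learning rate is now tracked in a local variable so that it can shrink by a constant factor after each automatic training cycle, with the next cycle's optimizer rebuilt from the decayed value. Below is a minimal, self-contained sketch of that decay pattern; create_optimizer and train_one_cycle here are hypothetical stand-ins for this repository's training loop, and only the decay arithmetic mirrors the commit.

import torch

# Hypothetical stand-ins for the actual training loop in train_ti.py.
def create_optimizer(params, lr):
    return torch.optim.AdamW(params, lr=lr)

def train_one_cycle(optimizer):
    pass  # one full training cycle (epochs, checkpoints, samples) runs here

base_lr = 1e-4      # stand-in for args.learning_rate
cycle_decay = 0.9   # stand-in for args.cycle_decay (default 1.0 = no decay)
num_cycles = 4      # stand-in for the auto-cycle count

params = [torch.nn.Parameter(torch.zeros(4))]
learning_rate = base_lr
for cycle in range(num_cycles):
    # each cycle rebuilds the optimizer from the current, decayed rate
    optimizer = create_optimizer(params, lr=learning_rate)
    train_one_cycle(optimizer)
    if base_lr is not None:  # mirrors the None guard in the diff
        learning_rate *= cycle_decay  # 1e-4 -> 9e-5 -> 8.1e-5 -> ...

With --cycle_decay at its default of 1.0 the behavior is unchanged; a value such as 0.9 runs cycle n at roughly base_lr * 0.9**n, letting later cycles fine-tune the embeddings more gently.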
