diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 13 |
1 file changed, 10 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 082e9b7..94ddbb6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -72,6 +72,12 @@ def parse_args(): | |||
72 | help="How many cycles to run automatically." | 72 | help="How many cycles to run automatically." |
73 | ) | 73 | ) |
74 | parser.add_argument( | 74 | parser.add_argument( |
75 | "--cycle_decay", | ||
76 | type=float, | ||
77 | default=1.0, | ||
78 | help="Learning rate decay per cycle." | ||
79 | ) | ||
80 | parser.add_argument( | ||
75 | "--placeholder_tokens", | 81 | "--placeholder_tokens", |
76 | type=str, | 82 | type=str, |
77 | nargs='*', | 83 | nargs='*', |
@@ -672,7 +678,6 @@ def main(): | |||
672 | convnext.to(accelerator.device, dtype=weight_dtype) | 678 | convnext.to(accelerator.device, dtype=weight_dtype) |
673 | convnext.requires_grad_(False) | 679 | convnext.requires_grad_(False) |
674 | convnext.eval() | 680 | convnext.eval() |
675 | disc = ConvNeXtDiscriminator(convnext, input_size=384) | ||
676 | 681 | ||
677 | if len(args.alias_tokens) != 0: | 682 | if len(args.alias_tokens) != 0: |
678 | alias_placeholder_tokens = args.alias_tokens[::2] | 683 | alias_placeholder_tokens = args.alias_tokens[::2] |
@@ -815,7 +820,6 @@ def main(): | |||
815 | milestone_checkpoints=not args.no_milestone_checkpoints, | 820 | milestone_checkpoints=not args.no_milestone_checkpoints, |
816 | global_step_offset=global_step_offset, | 821 | global_step_offset=global_step_offset, |
817 | offset_noise_strength=args.offset_noise_strength, | 822 | offset_noise_strength=args.offset_noise_strength, |
818 | disc=disc, | ||
819 | # -- | 823 | # -- |
820 | use_emb_decay=args.use_emb_decay, | 824 | use_emb_decay=args.use_emb_decay, |
821 | emb_decay_target=args.emb_decay_target, | 825 | emb_decay_target=args.emb_decay_target, |
@@ -890,6 +894,7 @@ def main(): | |||
890 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 894 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
891 | 895 | ||
892 | training_iter = 0 | 896 | training_iter = 0 |
897 | learning_rate = args.learning_rate | ||
893 | 898 | ||
894 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | 899 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" |
895 | 900 | ||
@@ -908,7 +913,7 @@ def main(): | |||
908 | 913 | ||
909 | optimizer = create_optimizer( | 914 | optimizer = create_optimizer( |
910 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 915 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
911 | lr=args.learning_rate, | 916 | lr=learning_rate, |
912 | ) | 917 | ) |
913 | 918 | ||
914 | lr_scheduler = get_scheduler( | 919 | lr_scheduler = get_scheduler( |
@@ -948,6 +953,8 @@ def main(): | |||
948 | ) | 953 | ) |
949 | 954 | ||
950 | training_iter += 1 | 955 | training_iter += 1 |
956 | if args.learning_rate is not None: | ||
957 | learning_rate *= args.cycle_decay | ||
951 | 958 | ||
952 | accelerator.end_training() | 959 | accelerator.end_training() |
953 | 960 | ||