Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 13

1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 082e9b7..94ddbb6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -72,6 +72,12 @@ def parse_args():
         help="How many cycles to run automatically."
     )
     parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
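The new --cycle_decay flag geometrically shrinks the learning rate from one
training cycle to the next. A minimal sketch of the resulting schedule (the
concrete values below are hypothetical; only the decay rule comes from the
patch):

    # Effective learning rate per cycle with --cycle_decay 0.9
    base_lr = 1e-3      # stand-in for args.learning_rate
    cycle_decay = 0.9   # stand-in for args.cycle_decay
    for cycle in range(4):
        print(f"cycle {cycle}: lr = {base_lr * cycle_decay ** cycle:.2e}")
    # cycle 0: lr = 1.00e-03
    # cycle 1: lr = 9.00e-04
    # cycle 2: lr = 8.10e-04
    # cycle 3: lr = 7.29e-04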
@@ -672,7 +678,6 @@ def main():
         convnext.to(accelerator.device, dtype=weight_dtype)
         convnext.requires_grad_(False)
         convnext.eval()
-        disc = ConvNeXtDiscriminator(convnext, input_size=384)
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -815,7 +820,6 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
-        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -890,6 +894,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
     training_iter = 0
+    learning_rate = args.learning_rate
 
     project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
@@ -908,7 +913,7 @@ def main():
 
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            lr=args.learning_rate,
+            lr=learning_rate,
         )
 
         lr_scheduler = get_scheduler(
@@ -948,6 +953,8 @@ def main():
         )
 
         training_iter += 1
+        if args.learning_rate is not None:
+            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
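Taken together, the learning-rate part of the patch threads per-cycle decay
through the auto-cycle loop: a mutable learning_rate starts at
args.learning_rate, the optimizer is rebuilt from it each cycle, and it is
multiplied by args.cycle_decay once a cycle finishes. A condensed,
hypothetical view of that control flow (args.auto_cycles is assumed from
the "How many cycles to run automatically." help text; embedding_params and
train_one_cycle are stand-ins for the script's own parameters and trainer
call):

    learning_rate = args.learning_rate
    training_iter = 0
    while training_iter < args.auto_cycles:  # cycle-count flag name assumed
        optimizer = create_optimizer(embedding_params, lr=learning_rate)
        train_one_cycle(optimizer)           # stand-in for the real trainer call
        training_iter += 1
        if args.learning_rate is not None:   # same guard as in the diff
            learning_rate *= args.cycle_decay

Because --cycle_decay defaults to 1.0, existing invocations keep a constant
learning rate across cycles.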