Diffstat (limited to 'train_ti.py')
 train_ti.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8c63493..89f4113 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -160,6 +160,12 @@ def parse_args():
160 help="A collection to filter the dataset.", 160 help="A collection to filter the dataset.",
161 ) 161 )
162 parser.add_argument( 162 parser.add_argument(
163 "--validation_prompts",
164 type=str,
165 nargs="*",
166 help="Prompts for additional validation images",
167 )
168 parser.add_argument(
163 "--seed", type=int, default=None, help="A seed for reproducible training." 169 "--seed", type=int, default=None, help="A seed for reproducible training."
164 ) 170 )
165 parser.add_argument( 171 parser.add_argument(
@@ -456,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -852,11 +858,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
     data_npgenerator = np.random.default_rng(args.seed)
 
@@ -957,6 +958,11 @@ def main():
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    optimizer = create_optimizer(
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        lr=args.learning_rate,
+    )
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
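
Note: besides the new --validation_prompts flag and the sample_steps default bump, the functional change above is that the optimizer over the token-embedding parameters is no longer built once during setup but is rebuilt just before the training-cycle loop. The following is a minimal, hypothetical sketch of that pattern, not the repository's code: TinyTextEncoder, the simplified create_optimizer, and the hard-coded learning rate are placeholders standing in for the real text encoder, the project's create_optimizer helper, and args.learning_rate.

# Minimal sketch (assumptions, not the repository's code): the optimizer over
# the token-embedding parameters is created inside the training-cycle loop
# rather than once at setup, so each cycle starts from fresh optimizer state.
import torch
import torch.nn as nn


class TinyTextEncoder(nn.Module):
    # Stand-in for the CLIP text encoder; only the token embedding is trained.
    def __init__(self, vocab_size: int = 16, dim: int = 8):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.token_embedding(ids).mean()


def create_optimizer(params, lr: float) -> torch.optim.Optimizer:
    # Simplified stand-in for the project's create_optimizer helper.
    return torch.optim.AdamW(params, lr=lr)


text_encoder = TinyTextEncoder()
auto_cycles = ["o", "o"]  # queued cycle commands, mirroring the interactive loop

while auto_cycles:
    auto_cycles.pop(0)
    # Built here, inside the cycle loop, rather than once during setup,
    # mirroring the move shown in the hunks above.
    optimizer = create_optimizer(
        text_encoder.token_embedding.parameters(),
        lr=1e-4,  # placeholder for args.learning_rate
    )
    for _ in range(3):  # a few dummy optimization steps per cycle
        loss = text_encoder(torch.randint(0, 16, (4,)))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()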