diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 8c63493..89f4113 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -160,6 +160,12 @@ def parse_args(): | |||
160 | help="A collection to filter the dataset.", | 160 | help="A collection to filter the dataset.", |
161 | ) | 161 | ) |
162 | parser.add_argument( | 162 | parser.add_argument( |
163 | "--validation_prompts", | ||
164 | type=str, | ||
165 | nargs="*", | ||
166 | help="Prompts for additional validation images", | ||
167 | ) | ||
168 | parser.add_argument( | ||
163 | "--seed", type=int, default=None, help="A seed for reproducible training." | 169 | "--seed", type=int, default=None, help="A seed for reproducible training." |
164 | ) | 170 | ) |
165 | parser.add_argument( | 171 | parser.add_argument( |
@@ -456,7 +462,7 @@ def parse_args(): | |||
456 | parser.add_argument( | 462 | parser.add_argument( |
457 | "--sample_steps", | 463 | "--sample_steps", |
458 | type=int, | 464 | type=int, |
459 | default=10, | 465 | default=15, |
460 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
461 | ) | 467 | ) |
462 | parser.add_argument( | 468 | parser.add_argument( |
@@ -852,11 +858,6 @@ def main(): | |||
852 | sample_image_size=args.sample_image_size, | 858 | sample_image_size=args.sample_image_size, |
853 | ) | 859 | ) |
854 | 860 | ||
855 | optimizer = create_optimizer( | ||
856 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
857 | lr=args.learning_rate, | ||
858 | ) | ||
859 | |||
860 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 861 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
861 | data_npgenerator = np.random.default_rng(args.seed) | 862 | data_npgenerator = np.random.default_rng(args.seed) |
862 | 863 | ||
@@ -957,6 +958,11 @@ def main(): | |||
957 | avg_loss_val = AverageMeter() | 958 | avg_loss_val = AverageMeter() |
958 | avg_acc_val = AverageMeter() | 959 | avg_acc_val = AverageMeter() |
959 | 960 | ||
961 | optimizer = create_optimizer( | ||
962 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
963 | lr=args.learning_rate, | ||
964 | ) | ||
965 | |||
960 | while True: | 966 | while True: |
961 | if len(auto_cycles) != 0: | 967 | if len(auto_cycles) != 0: |
962 | response = auto_cycles.pop(0) | 968 | response = auto_cycles.pop(0) |