From 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 16:26:22 +0200 Subject: Fixes --- train_ti.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 8c63493..89f4113 100644 --- a/train_ti.py +++ b/train_ti.py @@ -159,6 +159,12 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) + parser.add_argument( + "--validation_prompts", + type=str, + nargs="*", + help="Prompts for additional validation images", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -456,7 +462,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=10, + default=15, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -852,11 +858,6 @@ def main(): sample_image_size=args.sample_image_size, ) - optimizer = create_optimizer( - text_encoder.text_model.embeddings.token_embedding.parameters(), - lr=args.learning_rate, - ) - data_generator = torch.Generator(device="cpu").manual_seed(args.seed) data_npgenerator = np.random.default_rng(args.seed) @@ -957,6 +958,11 @@ def main(): avg_loss_val = AverageMeter() avg_acc_val = AverageMeter() + optimizer = create_optimizer( + text_encoder.text_model.embeddings.token_embedding.parameters(), + lr=args.learning_rate, + ) + while True: if len(auto_cycles) != 0: response = auto_cycles.pop(0) -- cgit v1.2.3-54-g00ecf