From b57ca669a150d9313447612fb8c37668f4f2a80d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 9 Jan 2023 10:19:37 +0100
Subject: Add --valid_set_repeat

---
 train_ti.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 03f52c4..7784d04 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -380,6 +380,12 @@ def parse_args():
         default=None,
         help="Number of images in the validation dataset."
     )
+    parser.add_argument(
+        "--valid_set_repeat",
+        type=int,
+        default=None,
+        help="Times the images in the validation dataset are repeated."
+    )
     parser.add_argument(
         "--train_batch_size",
         type=int,
@@ -398,6 +404,12 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--max_grad_norm",
+        default=3.0,
+        type=float,
+        help="Max gradient norm."
+    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -465,6 +477,9 @@ def parse_args():
     if isinstance(args.exclude_collections, str):
         args.exclude_collections = [args.exclude_collections]
 
+    if args.valid_set_repeat is None:
+        args.valid_set_repeat = args.train_batch_size
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -735,6 +750,7 @@ def main():
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
+        valid_set_repeat=args.valid_set_repeat,
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
@@ -961,6 +977,12 @@ def main():
 
                     accelerator.backward(loss)
 
+                    if accelerator.sync_gradients:
+                        accelerator.clip_grad_norm_(
+                            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+                            args.max_grad_norm
+                        )
+
                     optimizer.step()
                     if not accelerator.optimizer_step_was_skipped:
                         lr_scheduler.step()
-- 
cgit v1.2.3-54-g00ecf