summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-09 10:19:37 +0100
committerVolpeon <git@volpeon.ink>2023-01-09 10:19:37 +0100
commitb57ca669a150d9313447612fb8c37668f4f2a80d (patch)
treeb0ebfedc33c26847838850416b96fd2623cf6ba5 /train_ti.py
parentNo cache after all (diff)
downloadtextual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.gz
textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.bz2
textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.zip
Add --valid_set_repeat
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index 03f52c4..7784d04 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -381,6 +381,12 @@ def parse_args():
381 help="Number of images in the validation dataset." 381 help="Number of images in the validation dataset."
382 ) 382 )
383 parser.add_argument( 383 parser.add_argument(
384 "--valid_set_repeat",
385 type=int,
386 default=None,
387 help="Times the images in the validation dataset are repeated."
388 )
389 parser.add_argument(
384 "--train_batch_size", 390 "--train_batch_size",
385 type=int, 391 type=int,
386 default=1, 392 default=1,
@@ -399,6 +405,12 @@ def parse_args():
399 help="The weight of prior preservation loss." 405 help="The weight of prior preservation loss."
400 ) 406 )
401 parser.add_argument( 407 parser.add_argument(
408 "--max_grad_norm",
409 default=3.0,
410 type=float,
411 help="Max gradient norm."
412 )
413 parser.add_argument(
402 "--noise_timesteps", 414 "--noise_timesteps",
403 type=int, 415 type=int,
404 default=1000, 416 default=1000,
@@ -465,6 +477,9 @@ def parse_args():
465 if isinstance(args.exclude_collections, str): 477 if isinstance(args.exclude_collections, str):
466 args.exclude_collections = [args.exclude_collections] 478 args.exclude_collections = [args.exclude_collections]
467 479
480 if args.valid_set_repeat is None:
481 args.valid_set_repeat = args.train_batch_size
482
468 if args.output_dir is None: 483 if args.output_dir is None:
469 raise ValueError("You must specify --output_dir") 484 raise ValueError("You must specify --output_dir")
470 485
@@ -735,6 +750,7 @@ def main():
735 dropout=args.tag_dropout, 750 dropout=args.tag_dropout,
736 template_key=args.train_data_template, 751 template_key=args.train_data_template,
737 valid_set_size=args.valid_set_size, 752 valid_set_size=args.valid_set_size,
753 valid_set_repeat=args.valid_set_repeat,
738 num_workers=args.dataloader_num_workers, 754 num_workers=args.dataloader_num_workers,
739 seed=args.seed, 755 seed=args.seed,
740 filter=keyword_filter, 756 filter=keyword_filter,
@@ -961,6 +977,12 @@ def main():
961 977
962 accelerator.backward(loss) 978 accelerator.backward(loss)
963 979
980 if accelerator.sync_gradients:
981 accelerator.clip_grad_norm_(
982 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
983 args.max_grad_norm
984 )
985
964 optimizer.step() 986 optimizer.step()
965 if not accelerator.optimizer_step_was_skipped: 987 if not accelerator.optimizer_step_was_skipped:
966 lr_scheduler.step() 988 lr_scheduler.step()