summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-09 10:57:05 +0100
committerVolpeon <git@volpeon.ink>2023-01-09 10:57:05 +0100
commit7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb (patch)
tree3ad7621f9ebf3468ff46aee4b6a736fa52f7aaf4 /train_ti.py
parentAdd --valid_set_repeat (diff)
downloadtextual-inversion-diff-7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb.tar.gz
textual-inversion-diff-7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb.tar.bz2
textual-inversion-diff-7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb.zip
Enable buckets for validation, fixed vaildation repeat arg
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py
index 7784d04..df8d443 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -383,7 +383,7 @@ def parse_args():
383 parser.add_argument( 383 parser.add_argument(
384 "--valid_set_repeat", 384 "--valid_set_repeat",
385 type=int, 385 type=int,
386 default=None, 386 default=1,
387 help="Times the images in the validation dataset are repeated." 387 help="Times the images in the validation dataset are repeated."
388 ) 388 )
389 parser.add_argument( 389 parser.add_argument(
@@ -477,9 +477,6 @@ def parse_args():
477 if isinstance(args.exclude_collections, str): 477 if isinstance(args.exclude_collections, str):
478 args.exclude_collections = [args.exclude_collections] 478 args.exclude_collections = [args.exclude_collections]
479 479
480 if args.valid_set_repeat is None:
481 args.valid_set_repeat = args.train_batch_size
482
483 if args.output_dir is None: 480 if args.output_dir is None:
484 raise ValueError("You must specify --output_dir") 481 raise ValueError("You must specify --output_dir")
485 482