From 33e7d2ed37e32657ca94d92815043026c4cea7c0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 10 Jan 2023 09:22:02 +0100 Subject: Added arg to disable tag shuffling --- train_dreambooth.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 1a1f516..48a513c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -132,6 +132,12 @@ def parse_args(): default=0.1, help="Tag dropout probability.", ) + parser.add_argument( + "--tag_shuffle", + type="store_true", + default=True, + help="Shuffle tags.", + ) parser.add_argument( "--vector_dropout", type=int, @@ -398,7 +404,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=15, + default=20, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -768,6 +774,7 @@ def main(): bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, dropout=args.tag_dropout, + shuffle=args.tag_shuffle, template_key=args.train_data_template, valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, -- cgit v1.2.3-54-g00ecf