From 94b676d91382267e7429bd68362019868affd9d1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 13 Feb 2023 17:19:18 +0100 Subject: Update --- train_ti.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index c79dfa2..171d085 100644 --- a/train_ti.py +++ b/train_ti.py @@ -143,7 +143,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=4, + default=0, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -485,6 +485,9 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.train_data_template): raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") + else: + if isinstance(args.train_data_template, list): + raise ValueError("--train_data_template can't be a list in simultaneous mode") if isinstance(args.collection, str): args.collection = [args.collection] @@ -503,7 +506,7 @@ def main(): global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir = Path(args.output_dir)/slugify(args.project)/now output_dir.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -519,7 +522,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) if args.seed is None: args.seed = torch.random.seed() >> 32 @@ -570,7 +573,7 @@ def main(): else: optimizer_class = torch.optim.AdamW - checkpoint_output_dir = output_dir.joinpath("checkpoints") + checkpoint_output_dir = output_dir/"checkpoints" trainer = partial( train, @@ -611,11 +614,11 @@ def main(): return if len(placeholder_tokens) == 1: - sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") - metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") + sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" + metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" else: - sample_output_dir = output_dir.joinpath("samples") - metrics_output_file = output_dir.joinpath(f"lr.png") + sample_output_dir = output_dir/"samples" + metrics_output_file = output_dir/f"lr.png" placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, -- cgit v1.2.3-54-g00ecf