From b1d7b2962e6454f8e72bd64efe08dc80a1f2d3aa Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 2 Jan 2023 17:12:06 +0100 Subject: Fix --- train_ti.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 6c74854..97dde1e 100644 --- a/train_ti.py +++ b/train_ti.py @@ -519,6 +519,8 @@ def main(): args.seed = args.seed or (torch.random.seed() >> 32) set_seed(args.seed) + save_args(basepath, args) + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -903,10 +905,6 @@ def main(): ) global_progress_bar.set_description("Total progress") - save_args(basepath, args, { - "global_step": global_step + global_step_offset - }) - try: for epoch in range(num_epochs): if accelerator.is_main_process: -- cgit v1.2.3-54-g00ecf