From b1d7b2962e6454f8e72bd64efe08dc80a1f2d3aa Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 2 Jan 2023 17:12:06 +0100 Subject: Fix --- train_dreambooth.py | 4 ---- train_ti.py | 6 ++---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index cd0bf67..05f6cb5 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -955,10 +955,6 @@ def main(): ) global_progress_bar.set_description("Total progress") - save_args(basepath, args, { - "global_step": global_step + global_step_offset - }) - try: for epoch in range(num_epochs): if accelerator.is_main_process: diff --git a/train_ti.py b/train_ti.py index 6c74854..97dde1e 100644 --- a/train_ti.py +++ b/train_ti.py @@ -519,6 +519,8 @@ def main(): args.seed = args.seed or (torch.random.seed() >> 32) set_seed(args.seed) + save_args(basepath, args) + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) @@ -903,10 +905,6 @@ def main(): ) global_progress_bar.set_description("Total progress") - save_args(basepath, args, { - "global_step": global_step + global_step_offset - }) - try: for epoch in range(num_epochs): if accelerator.is_main_process: -- cgit v1.2.3-70-g09d2