From ad9128d1131f2ae298cee56a2393486806f23c73 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 2 Jan 2023 15:56:44 +0100 Subject: Update --- train_dreambooth.py | 13 ++++++------- train_ti.py | 12 ++++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index 92f9b96..c0fe328 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -941,9 +941,6 @@ def main(): seed=args.seed ) - if accelerator.is_main_process: - checkpointer.save_samples(0, args.sample_steps) - local_progress_bar = tqdm( range(num_update_steps_per_epoch + num_val_steps_per_epoch), disable=not accelerator.is_local_main_process, @@ -960,6 +957,10 @@ def main(): try: for epoch in range(num_epochs): + if accelerator.is_main_process: + if epoch % args.sample_frequency == 0: + checkpointer.save_samples(global_step, args.sample_steps) + local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") local_progress_bar.reset() @@ -1064,19 +1065,17 @@ def main(): f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") max_acc_val = avg_acc_val.avg.item() - if (epoch + 1) % args.sample_frequency == 0: - checkpointer.save_samples(global_step, args.sample_steps) - # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") + checkpointer.save_samples(0, args.sample_steps) checkpointer.save_model() - accelerator.end_training() except KeyboardInterrupt: if accelerator.is_main_process: print("Interrupted, saving checkpoint and resume state...") + checkpointer.save_samples(0, args.sample_steps) checkpointer.save_model() accelerator.end_training() quit() diff --git a/train_ti.py b/train_ti.py index 775b918..f1dbed1 100644 --- a/train_ti.py +++ b/train_ti.py @@ -889,9 +889,6 @@ def main(): seed=args.seed ) - if accelerator.is_main_process: - checkpointer.save_samples(global_step_offset, args.sample_steps) - local_progress_bar = tqdm( range(num_update_steps_per_epoch + num_val_steps_per_epoch), disable=not accelerator.is_local_main_process, @@ -908,6 +905,10 @@ def main(): try: for epoch in range(num_epochs): + if accelerator.is_main_process: + if epoch % args.sample_frequency == 0: + checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) + local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") local_progress_bar.reset() @@ -1001,13 +1002,11 @@ def main(): "global_step": global_step + global_step_offset }) - if (epoch + 1) % args.sample_frequency == 0: - checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) - # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") checkpointer.checkpoint(global_step + global_step_offset, "end") + checkpointer.save_samples(global_step_offset, args.sample_steps) save_args(basepath, args, { "global_step": global_step + global_step_offset }) @@ -1017,6 +1016,7 @@ def main(): if accelerator.is_main_process: print("Interrupted, saving checkpoint and resume state...") checkpointer.checkpoint(global_step + global_step_offset, "end") + checkpointer.save_samples(global_step_offset, args.sample_steps) save_args(basepath, args, { "global_step": global_step + global_step_offset }) -- cgit v1.2.3-70-g09d2