summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index f1dbed1..5cbef06 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1006,7 +1006,7 @@ def main():
1006 if accelerator.is_main_process: 1006 if accelerator.is_main_process:
1007 print("Finished! Saving final checkpoint and resume state.") 1007 print("Finished! Saving final checkpoint and resume state.")
1008 checkpointer.checkpoint(global_step + global_step_offset, "end") 1008 checkpointer.checkpoint(global_step + global_step_offset, "end")
1009 checkpointer.save_samples(global_step_offset, args.sample_steps) 1009 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
1010 save_args(basepath, args, { 1010 save_args(basepath, args, {
1011 "global_step": global_step + global_step_offset 1011 "global_step": global_step + global_step_offset
1012 }) 1012 })
@@ -1016,7 +1016,6 @@ def main():
1016 if accelerator.is_main_process: 1016 if accelerator.is_main_process:
1017 print("Interrupted, saving checkpoint and resume state...") 1017 print("Interrupted, saving checkpoint and resume state...")
1018 checkpointer.checkpoint(global_step + global_step_offset, "end") 1018 checkpointer.checkpoint(global_step + global_step_offset, "end")
1019 checkpointer.save_samples(global_step_offset, args.sample_steps)
1020 save_args(basepath, args, { 1019 save_args(basepath, args, {
1021 "global_step": global_step + global_step_offset 1020 "global_step": global_step + global_step_offset
1022 }) 1021 })