summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py3
-rw-r--r--train_ti.py3
2 files changed, 2 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c0fe328..05f6cb5 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1068,14 +1068,13 @@ def main():
1068 # Create the pipeline using using the trained modules and save it. 1068 # Create the pipeline using using the trained modules and save it.
1069 if accelerator.is_main_process: 1069 if accelerator.is_main_process:
1070 print("Finished! Saving final checkpoint and resume state.") 1070 print("Finished! Saving final checkpoint and resume state.")
1071 checkpointer.save_samples(0, args.sample_steps) 1071 checkpointer.save_samples(global_step, args.sample_steps)
1072 checkpointer.save_model() 1072 checkpointer.save_model()
1073 accelerator.end_training() 1073 accelerator.end_training()
1074 1074
1075 except KeyboardInterrupt: 1075 except KeyboardInterrupt:
1076 if accelerator.is_main_process: 1076 if accelerator.is_main_process:
1077 print("Interrupted, saving checkpoint and resume state...") 1077 print("Interrupted, saving checkpoint and resume state...")
1078 checkpointer.save_samples(0, args.sample_steps)
1079 checkpointer.save_model() 1078 checkpointer.save_model()
1080 accelerator.end_training() 1079 accelerator.end_training()
1081 quit() 1080 quit()
diff --git a/train_ti.py b/train_ti.py
index f1dbed1..5cbef06 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1006,7 +1006,7 @@ def main():
1006 if accelerator.is_main_process: 1006 if accelerator.is_main_process:
1007 print("Finished! Saving final checkpoint and resume state.") 1007 print("Finished! Saving final checkpoint and resume state.")
1008 checkpointer.checkpoint(global_step + global_step_offset, "end") 1008 checkpointer.checkpoint(global_step + global_step_offset, "end")
1009 checkpointer.save_samples(global_step_offset, args.sample_steps) 1009 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
1010 save_args(basepath, args, { 1010 save_args(basepath, args, {
1011 "global_step": global_step + global_step_offset 1011 "global_step": global_step + global_step_offset
1012 }) 1012 })
@@ -1016,7 +1016,6 @@ def main():
1016 if accelerator.is_main_process: 1016 if accelerator.is_main_process:
1017 print("Interrupted, saving checkpoint and resume state...") 1017 print("Interrupted, saving checkpoint and resume state...")
1018 checkpointer.checkpoint(global_step + global_step_offset, "end") 1018 checkpointer.checkpoint(global_step + global_step_offset, "end")
1019 checkpointer.save_samples(global_step_offset, args.sample_steps)
1020 save_args(basepath, args, { 1019 save_args(basepath, args, {
1021 "global_step": global_step + global_step_offset 1020 "global_step": global_step + global_step_offset
1022 }) 1021 })