diff options
-rw-r--r-- | train_dreambooth.py | 3 | ||||
-rw-r--r-- | train_ti.py | 3 |
2 files changed, 2 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index c0fe328..05f6cb5 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -1068,14 +1068,13 @@ def main(): | |||
1068 | # Create the pipeline using using the trained modules and save it. | 1068 | # Create the pipeline using using the trained modules and save it. |
1069 | if accelerator.is_main_process: | 1069 | if accelerator.is_main_process: |
1070 | print("Finished! Saving final checkpoint and resume state.") | 1070 | print("Finished! Saving final checkpoint and resume state.") |
1071 | checkpointer.save_samples(0, args.sample_steps) | 1071 | checkpointer.save_samples(global_step, args.sample_steps) |
1072 | checkpointer.save_model() | 1072 | checkpointer.save_model() |
1073 | accelerator.end_training() | 1073 | accelerator.end_training() |
1074 | 1074 | ||
1075 | except KeyboardInterrupt: | 1075 | except KeyboardInterrupt: |
1076 | if accelerator.is_main_process: | 1076 | if accelerator.is_main_process: |
1077 | print("Interrupted, saving checkpoint and resume state...") | 1077 | print("Interrupted, saving checkpoint and resume state...") |
1078 | checkpointer.save_samples(0, args.sample_steps) | ||
1079 | checkpointer.save_model() | 1078 | checkpointer.save_model() |
1080 | accelerator.end_training() | 1079 | accelerator.end_training() |
1081 | quit() | 1080 | quit() |
diff --git a/train_ti.py b/train_ti.py index f1dbed1..5cbef06 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -1006,7 +1006,7 @@ def main(): | |||
1006 | if accelerator.is_main_process: | 1006 | if accelerator.is_main_process: |
1007 | print("Finished! Saving final checkpoint and resume state.") | 1007 | print("Finished! Saving final checkpoint and resume state.") |
1008 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1008 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
1009 | checkpointer.save_samples(global_step_offset, args.sample_steps) | 1009 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) |
1010 | save_args(basepath, args, { | 1010 | save_args(basepath, args, { |
1011 | "global_step": global_step + global_step_offset | 1011 | "global_step": global_step + global_step_offset |
1012 | }) | 1012 | }) |
@@ -1016,7 +1016,6 @@ def main(): | |||
1016 | if accelerator.is_main_process: | 1016 | if accelerator.is_main_process: |
1017 | print("Interrupted, saving checkpoint and resume state...") | 1017 | print("Interrupted, saving checkpoint and resume state...") |
1018 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1018 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
1019 | checkpointer.save_samples(global_step_offset, args.sample_steps) | ||
1020 | save_args(basepath, args, { | 1019 | save_args(basepath, args, { |
1021 | "global_step": global_step + global_step_offset | 1020 | "global_step": global_step + global_step_offset |
1022 | }) | 1021 | }) |