diff options
author | Volpeon <git@volpeon.ink> | 2023-01-02 15:56:44 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-02 15:56:44 +0100 |
commit | ad9128d1131f2ae298cee56a2393486806f23c73 (patch) | |
tree | b2e8f87026183294ab9b698d79f5176a250a7263 /train_ti.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-ad9128d1131f2ae298cee56a2393486806f23c73.tar.gz textual-inversion-diff-ad9128d1131f2ae298cee56a2393486806f23c73.tar.bz2 textual-inversion-diff-ad9128d1131f2ae298cee56a2393486806f23c73.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 775b918..f1dbed1 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -889,9 +889,6 @@ def main(): | |||
889 | seed=args.seed | 889 | seed=args.seed |
890 | ) | 890 | ) |
891 | 891 | ||
892 | if accelerator.is_main_process: | ||
893 | checkpointer.save_samples(global_step_offset, args.sample_steps) | ||
894 | |||
895 | local_progress_bar = tqdm( | 892 | local_progress_bar = tqdm( |
896 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 893 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
897 | disable=not accelerator.is_local_main_process, | 894 | disable=not accelerator.is_local_main_process, |
@@ -908,6 +905,10 @@ def main(): | |||
908 | 905 | ||
909 | try: | 906 | try: |
910 | for epoch in range(num_epochs): | 907 | for epoch in range(num_epochs): |
908 | if accelerator.is_main_process: | ||
909 | if epoch % args.sample_frequency == 0: | ||
910 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | ||
911 | |||
911 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 912 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
912 | local_progress_bar.reset() | 913 | local_progress_bar.reset() |
913 | 914 | ||
@@ -1001,13 +1002,11 @@ def main(): | |||
1001 | "global_step": global_step + global_step_offset | 1002 | "global_step": global_step + global_step_offset |
1002 | }) | 1003 | }) |
1003 | 1004 | ||
1004 | if (epoch + 1) % args.sample_frequency == 0: | ||
1005 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | ||
1006 | |||
1007 | # Create the pipeline using using the trained modules and save it. | 1005 | # Create the pipeline using using the trained modules and save it. |
1008 | if accelerator.is_main_process: | 1006 | if accelerator.is_main_process: |
1009 | print("Finished! Saving final checkpoint and resume state.") | 1007 | print("Finished! Saving final checkpoint and resume state.") |
1010 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1008 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
1009 | checkpointer.save_samples(global_step_offset, args.sample_steps) | ||
1011 | save_args(basepath, args, { | 1010 | save_args(basepath, args, { |
1012 | "global_step": global_step + global_step_offset | 1011 | "global_step": global_step + global_step_offset |
1013 | }) | 1012 | }) |
@@ -1017,6 +1016,7 @@ def main(): | |||
1017 | if accelerator.is_main_process: | 1016 | if accelerator.is_main_process: |
1018 | print("Interrupted, saving checkpoint and resume state...") | 1017 | print("Interrupted, saving checkpoint and resume state...") |
1019 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1018 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
1019 | checkpointer.save_samples(global_step_offset, args.sample_steps) | ||
1020 | save_args(basepath, args, { | 1020 | save_args(basepath, args, { |
1021 | "global_step": global_step + global_step_offset | 1021 | "global_step": global_step + global_step_offset |
1022 | }) | 1022 | }) |