diff options (diffstat)

-rw-r--r--  train_dreambooth.py | 13 ++++++-------
-rw-r--r--  train_ti.py         | 12 ++++++------
2 files changed, 12 insertions(+), 13 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 92f9b96..c0fe328 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -941,9 +941,6 @@ def main():
         seed=args.seed
     )
 
-    if accelerator.is_main_process:
-        checkpointer.save_samples(0, args.sample_steps)
-
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -960,6 +957,10 @@ def main():
 
     try:
         for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % args.sample_frequency == 0:
+                    checkpointer.save_samples(global_step, args.sample_steps)
+
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
@@ -1064,19 +1065,17 @@ def main():
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                 max_acc_val = avg_acc_val.avg.item()
 
-            if (epoch + 1) % args.sample_frequency == 0:
-                checkpointer.save_samples(global_step, args.sample_steps)
-
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.save_samples(0, args.sample_steps)
             checkpointer.save_model()
-
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.save_samples(0, args.sample_steps)
             checkpointer.save_model()
             accelerator.end_training()
         quit()
diff --git a/train_ti.py b/train_ti.py
index 775b918..f1dbed1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -889,9 +889,6 @@ def main():
         seed=args.seed
     )
 
-    if accelerator.is_main_process:
-        checkpointer.save_samples(global_step_offset, args.sample_steps)
-
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -908,6 +905,10 @@ def main():
 
     try:
         for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % args.sample_frequency == 0:
+                    checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
+
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
@@ -1001,13 +1002,11 @@ def main():
                 "global_step": global_step + global_step_offset
             })
 
-            if (epoch + 1) % args.sample_frequency == 0:
-                checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
-
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_samples(global_step_offset, args.sample_steps)
             save_args(basepath, args, {
                 "global_step": global_step + global_step_offset
             })
@@ -1017,6 +1016,7 @@ def main():
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_samples(global_step_offset, args.sample_steps)
             save_args(basepath, args, {
                 "global_step": global_step + global_step_offset
             })
