diff options
author | Volpeon <git@volpeon.ink> | 2023-01-02 15:56:44 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-02 15:56:44 +0100 |
commit | ad9128d1131f2ae298cee56a2393486806f23c73 (patch) | |
tree | b2e8f87026183294ab9b698d79f5176a250a7263 /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-ad9128d1131f2ae298cee56a2393486806f23c73.tar.gz textual-inversion-diff-ad9128d1131f2ae298cee56a2393486806f23c73.tar.bz2 textual-inversion-diff-ad9128d1131f2ae298cee56a2393486806f23c73.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 92f9b96..c0fe328 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -941,9 +941,6 @@ def main(): | |||
941 | seed=args.seed | 941 | seed=args.seed |
942 | ) | 942 | ) |
943 | 943 | ||
944 | if accelerator.is_main_process: | ||
945 | checkpointer.save_samples(0, args.sample_steps) | ||
946 | |||
947 | local_progress_bar = tqdm( | 944 | local_progress_bar = tqdm( |
948 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 945 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
949 | disable=not accelerator.is_local_main_process, | 946 | disable=not accelerator.is_local_main_process, |
@@ -960,6 +957,10 @@ def main(): | |||
960 | 957 | ||
961 | try: | 958 | try: |
962 | for epoch in range(num_epochs): | 959 | for epoch in range(num_epochs): |
960 | if accelerator.is_main_process: | ||
961 | if epoch % args.sample_frequency == 0: | ||
962 | checkpointer.save_samples(global_step, args.sample_steps) | ||
963 | |||
963 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 964 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
964 | local_progress_bar.reset() | 965 | local_progress_bar.reset() |
965 | 966 | ||
@@ -1064,19 +1065,17 @@ def main(): | |||
1064 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 1065 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
1065 | max_acc_val = avg_acc_val.avg.item() | 1066 | max_acc_val = avg_acc_val.avg.item() |
1066 | 1067 | ||
1067 | if (epoch + 1) % args.sample_frequency == 0: | ||
1068 | checkpointer.save_samples(global_step, args.sample_steps) | ||
1069 | |||
1070 | # Create the pipeline using using the trained modules and save it. | 1068 | # Create the pipeline using using the trained modules and save it. |
1071 | if accelerator.is_main_process: | 1069 | if accelerator.is_main_process: |
1072 | print("Finished! Saving final checkpoint and resume state.") | 1070 | print("Finished! Saving final checkpoint and resume state.") |
1071 | checkpointer.save_samples(0, args.sample_steps) | ||
1073 | checkpointer.save_model() | 1072 | checkpointer.save_model() |
1074 | |||
1075 | accelerator.end_training() | 1073 | accelerator.end_training() |
1076 | 1074 | ||
1077 | except KeyboardInterrupt: | 1075 | except KeyboardInterrupt: |
1078 | if accelerator.is_main_process: | 1076 | if accelerator.is_main_process: |
1079 | print("Interrupted, saving checkpoint and resume state...") | 1077 | print("Interrupted, saving checkpoint and resume state...") |
1078 | checkpointer.save_samples(0, args.sample_steps) | ||
1080 | checkpointer.save_model() | 1079 | checkpointer.save_model() |
1081 | accelerator.end_training() | 1080 | accelerator.end_training() |
1082 | quit() | 1081 | quit() |