-rw-r--r--  train_dreambooth.py  13
-rw-r--r--  train_ti.py          12
2 files changed, 12 insertions(+), 13 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 92f9b96..c0fe328 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -941,9 +941,6 @@ def main():
         seed=args.seed
     )
 
-    if accelerator.is_main_process:
-        checkpointer.save_samples(0, args.sample_steps)
-
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -960,6 +957,10 @@ def main():
 
     try:
         for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % args.sample_frequency == 0:
+                    checkpointer.save_samples(global_step, args.sample_steps)
+
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
@@ -1064,19 +1065,17 @@ def main():
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     max_acc_val = avg_acc_val.avg.item()
 
-            if (epoch + 1) % args.sample_frequency == 0:
-                checkpointer.save_samples(global_step, args.sample_steps)
-
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.save_samples(0, args.sample_steps)
             checkpointer.save_model()
-
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.save_samples(0, args.sample_steps)
             checkpointer.save_model()
         accelerator.end_training()
         quit()
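
Taken together, the train_dreambooth.py hunks change when sample images are generated: the unconditional pass before training and the end-of-epoch check are replaced by a single check at the top of each epoch (so epoch 0 still samples before any training), plus a final pass when training finishes or is interrupted. Below is a minimal, runnable sketch of the resulting schedule, with the accelerator guard dropped and Checkpointer stubbed out; the stub and the concrete numbers are illustrative assumptions, not the real class.

    # Stand-in for the real Checkpointer, which renders sample images with
    # the current model weights.
    class Checkpointer:
        def save_samples(self, step: int, sample_steps: int) -> None:
            print(f"sampling at step {step} with {sample_steps} inference steps")

    def train(num_epochs: int, steps_per_epoch: int, sample_frequency: int) -> None:
        checkpointer = Checkpointer()
        global_step = 0
        for epoch in range(num_epochs):
            # Sampling now happens at the start of every sample_frequency-th
            # epoch; epoch 0 reproduces the old "sample once before training".
            if epoch % sample_frequency == 0:
                checkpointer.save_samples(global_step, sample_steps=20)
            global_step += steps_per_epoch  # stand-in for the real train loop
        # One final pass on completion; the diff adds the same call to the
        # KeyboardInterrupt handler, labelled step 0 as in the commit.
        checkpointer.save_samples(0, sample_steps=20)

    train(num_epochs=5, steps_per_epoch=100, sample_frequency=2)
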
diff --git a/train_ti.py b/train_ti.py
index 775b918..f1dbed1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -889,9 +889,6 @@ def main():
         seed=args.seed
     )
 
-    if accelerator.is_main_process:
-        checkpointer.save_samples(global_step_offset, args.sample_steps)
-
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -908,6 +905,10 @@ def main():
 
     try:
         for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % args.sample_frequency == 0:
+                    checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
+
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
@@ -1001,13 +1002,11 @@ def main():
                     "global_step": global_step + global_step_offset
                 })
 
-            if (epoch + 1) % args.sample_frequency == 0:
-                checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
-
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_samples(global_step_offset, args.sample_steps)
             save_args(basepath, args, {
                 "global_step": global_step + global_step_offset
             })
@@ -1017,6 +1016,7 @@ def main():
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_samples(global_step_offset, args.sample_steps)
             save_args(basepath, args, {
                 "global_step": global_step + global_step_offset
             })
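
train_ti.py gets the same restructuring, with one resume-related difference: the epoch-start call passes global_step + global_step_offset, so sample numbering continues from the absolute step of the whole run rather than restarting at zero, while the final and interrupt paths pass global_step_offset, mirroring the save_samples(0, ...) calls in train_dreambooth.py. A small sketch of that step accounting follows; the concrete numbers and the print stand-in are illustrative assumptions.

    # global_step_offset would be restored from a previous run's saved args;
    # global_step counts only steps taken by the current process.
    global_step_offset = 1000   # e.g. steps completed before resuming
    global_step = 0
    sample_frequency = 2
    steps_per_epoch = 500

    for epoch in range(4):
        if epoch % sample_frequency == 0:
            # Tag samples with the absolute step so a resumed run keeps
            # numbering where the previous one left off (1000, 2000, ...).
            print(f"save_samples at step {global_step + global_step_offset}")
        global_step += steps_per_epoch  # stand-in for one epoch of training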