summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 072142e..1ba8dc0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -70,7 +70,7 @@ def parse_args():
70 "--num_class_images", 70 "--num_class_images",
71 type=int, 71 type=int,
72 default=400, 72 default=400,
73 help="How many class images to generate per training image." 73 help="How many class images to generate."
74 ) 74 )
75 parser.add_argument( 75 parser.add_argument(
76 "--repeats", 76 "--repeats",
@@ -112,7 +112,7 @@ def parse_args():
112 parser.add_argument( 112 parser.add_argument(
113 "--max_train_steps", 113 "--max_train_steps",
114 type=int, 114 type=int,
115 default=3000, 115 default=2000,
116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
117 ) 117 )
118 parser.add_argument( 118 parser.add_argument(
@@ -341,7 +341,7 @@ class Checkpointer:
341 self.sample_batch_size = sample_batch_size 341 self.sample_batch_size = sample_batch_size
342 342
343 @torch.no_grad() 343 @torch.no_grad()
344 def checkpoint(self): 344 def save_model(self):
345 print("Saving model...") 345 print("Saving model...")
346 346
347 unwrapped = self.accelerator.unwrap_model( 347 unwrapped = self.accelerator.unwrap_model(
@@ -839,14 +839,14 @@ def main():
839 # Create the pipeline using using the trained modules and save it. 839 # Create the pipeline using using the trained modules and save it.
840 if accelerator.is_main_process: 840 if accelerator.is_main_process:
841 print("Finished! Saving final checkpoint and resume state.") 841 print("Finished! Saving final checkpoint and resume state.")
842 checkpointer.checkpoint() 842 checkpointer.save_model()
843 843
844 accelerator.end_training() 844 accelerator.end_training()
845 845
846 except KeyboardInterrupt: 846 except KeyboardInterrupt:
847 if accelerator.is_main_process: 847 if accelerator.is_main_process:
848 print("Interrupted, saving checkpoint and resume state...") 848 print("Interrupted, saving checkpoint and resume state...")
849 checkpointer.checkpoint() 849 checkpointer.save_model()
850 accelerator.end_training() 850 accelerator.end_training()
851 quit() 851 quit()
852 852