From 6a49074dce78615bce54777fb2be3bfd0dd8f780 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 14 Oct 2022 20:03:01 +0200 Subject: Removed aesthetic gradients; training improvements --- dreambooth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 072142e..1ba8dc0 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -70,7 +70,7 @@ def parse_args(): "--num_class_images", type=int, default=400, - help="How many class images to generate per training image." + help="How many class images to generate." ) parser.add_argument( "--repeats", @@ -112,7 +112,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=3000, + default=2000, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -341,7 +341,7 @@ class Checkpointer: self.sample_batch_size = sample_batch_size @torch.no_grad() - def checkpoint(self): + def save_model(self): print("Saving model...") unwrapped = self.accelerator.unwrap_model( @@ -839,14 +839,14 @@ def main(): # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") - checkpointer.checkpoint() + checkpointer.save_model() accelerator.end_training() except KeyboardInterrupt: if accelerator.is_main_process: print("Interrupted, saving checkpoint and resume state...") - checkpointer.checkpoint() + checkpointer.save_model() accelerator.end_training() quit() -- cgit v1.2.3-54-g00ecf