diff options
author | Volpeon <git@volpeon.ink> | 2022-10-14 20:03:01 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-14 20:03:01 +0200 |
commit | 6a49074dce78615bce54777fb2be3bfd0dd8f780 (patch) | |
tree | 0f7dde5ea6b6343fb6e0a527e5ebb2940d418dce /dreambooth.py | |
parent | Added support for Aesthetic Gradients (diff) | |
download | textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.gz textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.tar.bz2 textual-inversion-diff-6a49074dce78615bce54777fb2be3bfd0dd8f780.zip |
Removed aesthetic gradients; training improvements
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py index 072142e..1ba8dc0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -70,7 +70,7 @@ def parse_args(): | |||
70 | "--num_class_images", | 70 | "--num_class_images", |
71 | type=int, | 71 | type=int, |
72 | default=400, | 72 | default=400, |
73 | help="How many class images to generate per training image." | 73 | help="How many class images to generate." |
74 | ) | 74 | ) |
75 | parser.add_argument( | 75 | parser.add_argument( |
76 | "--repeats", | 76 | "--repeats", |
@@ -112,7 +112,7 @@ def parse_args(): | |||
112 | parser.add_argument( | 112 | parser.add_argument( |
113 | "--max_train_steps", | 113 | "--max_train_steps", |
114 | type=int, | 114 | type=int, |
115 | default=3000, | 115 | default=2000, |
116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
117 | ) | 117 | ) |
118 | parser.add_argument( | 118 | parser.add_argument( |
@@ -341,7 +341,7 @@ class Checkpointer: | |||
341 | self.sample_batch_size = sample_batch_size | 341 | self.sample_batch_size = sample_batch_size |
342 | 342 | ||
343 | @torch.no_grad() | 343 | @torch.no_grad() |
344 | def checkpoint(self): | 344 | def save_model(self): |
345 | print("Saving model...") | 345 | print("Saving model...") |
346 | 346 | ||
347 | unwrapped = self.accelerator.unwrap_model( | 347 | unwrapped = self.accelerator.unwrap_model( |
@@ -839,14 +839,14 @@ def main(): | |||
839 | # Create the pipeline using using the trained modules and save it. | 839 | # Create the pipeline using using the trained modules and save it. |
840 | if accelerator.is_main_process: | 840 | if accelerator.is_main_process: |
841 | print("Finished! Saving final checkpoint and resume state.") | 841 | print("Finished! Saving final checkpoint and resume state.") |
842 | checkpointer.checkpoint() | 842 | checkpointer.save_model() |
843 | 843 | ||
844 | accelerator.end_training() | 844 | accelerator.end_training() |
845 | 845 | ||
846 | except KeyboardInterrupt: | 846 | except KeyboardInterrupt: |
847 | if accelerator.is_main_process: | 847 | if accelerator.is_main_process: |
848 | print("Interrupted, saving checkpoint and resume state...") | 848 | print("Interrupted, saving checkpoint and resume state...") |
849 | checkpointer.checkpoint() | 849 | checkpointer.save_model() |
850 | accelerator.end_training() | 850 | accelerator.end_training() |
851 | quit() | 851 | quit() |
852 | 852 | ||