diff options
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 59 |
1 file changed, 50 insertions, 9 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 7996bc2..b5ec2fc 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -58,6 +58,12 @@ def parse_args(): | |||
58 | parser.add_argument( | 58 | parser.add_argument( |
59 | "--placeholder_token", | 59 | "--placeholder_token", |
60 | type=str, | 60 | type=str, |
61 | default="<*>", | ||
62 | help="A token to use as a placeholder for the concept.", | ||
63 | ) | ||
64 | parser.add_argument( | ||
65 | "--class_identifier", | ||
66 | type=str, | ||
61 | default=None, | 67 | default=None, |
62 | help="A token to use as a placeholder for the concept.", | 68 | help="A class identifier for the concept (used in class image prompts).", |
63 | ) | 69 | ) |
@@ -71,7 +77,7 @@ def parse_args(): | |||
71 | "--num_class_images", | 77 | "--num_class_images", |
72 | type=int, | 78 | type=int, |
73 | default=400, | 79 | default=400, |
74 | help="How many class images to generate per training image." | 80 | help="How many class images to generate." |
75 | ) | 81 | ) |
76 | parser.add_argument( | 82 | parser.add_argument( |
77 | "--repeats", | 83 | "--repeats", |
@@ -112,7 +118,7 @@ def parse_args(): | |||
112 | parser.add_argument( | 118 | parser.add_argument( |
113 | "--max_train_steps", | 119 | "--max_train_steps", |
114 | type=int, | 120 | type=int, |
115 | default=1600, | 121 | default=2300, |
116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 122 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
117 | ) | 123 | ) |
118 | parser.add_argument( | 124 | parser.add_argument( |
@@ -135,7 +141,7 @@ def parse_args(): | |||
135 | parser.add_argument( | 141 | parser.add_argument( |
136 | "--learning_rate_text", | 142 | "--learning_rate_text", |
137 | type=float, | 143 | type=float, |
138 | default=5e-4, | 144 | default=5e-6, |
139 | help="Initial learning rate (after the potential warmup period) to use.", | 145 | help="Initial learning rate (after the potential warmup period) to use.", |
140 | ) | 146 | ) |
141 | parser.add_argument( | 147 | parser.add_argument( |
@@ -222,6 +228,12 @@ def parse_args(): | |||
222 | ), | 228 | ), |
223 | ) | 229 | ) |
224 | parser.add_argument( | 230 | parser.add_argument( |
231 | "--checkpoint_frequency", | ||
232 | type=int, | ||
233 | default=500, | ||
234 | help="How often to save a checkpoint and sample image", | ||
235 | ) | ||
236 | parser.add_argument( | ||
225 | "--sample_frequency", | 237 | "--sample_frequency", |
226 | type=int, | 238 | type=int, |
227 | default=100, | 239 | default=100, |
@@ -352,7 +364,26 @@ class Checkpointer: | |||
352 | self.sample_batch_size = sample_batch_size | 364 | self.sample_batch_size = sample_batch_size |
353 | 365 | ||
354 | @torch.no_grad() | 366 | @torch.no_grad() |
355 | def checkpoint(self): | 367 | def checkpoint(self, step, postfix): |
368 | print("Saving checkpoint for step %d..." % step) | ||
369 | |||
370 | checkpoints_path = self.output_dir.joinpath("checkpoints") | ||
371 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
372 | |||
373 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | ||
374 | |||
375 | # Save a checkpoint | ||
376 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | ||
377 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | ||
378 | |||
379 | filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | ||
380 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | ||
381 | |||
382 | del unwrapped | ||
383 | del learned_embeds | ||
384 | |||
385 | @torch.no_grad() | ||
386 | def save_model(self): | ||
356 | print("Saving model...") | 387 | print("Saving model...") |
357 | 388 | ||
358 | unwrapped_unet = self.accelerator.unwrap_model( | 389 | unwrapped_unet = self.accelerator.unwrap_model( |
@@ -612,7 +643,7 @@ def main(): | |||
612 | batch_size=args.train_batch_size, | 643 | batch_size=args.train_batch_size, |
613 | tokenizer=tokenizer, | 644 | tokenizer=tokenizer, |
614 | instance_identifier=args.placeholder_token, | 645 | instance_identifier=args.placeholder_token, |
615 | class_identifier=args.initializer_token, | 646 | class_identifier=args.class_identifier, |
616 | class_subdir="cls", | 647 | class_subdir="cls", |
617 | num_class_images=args.num_class_images, | 648 | num_class_images=args.num_class_images, |
618 | size=args.resolution, | 649 | size=args.resolution, |
@@ -648,7 +679,7 @@ def main(): | |||
648 | with torch.inference_mode(): | 679 | with torch.inference_mode(): |
649 | for batch in batched_data: | 680 | for batch in batched_data: |
650 | image_name = [p.class_image_path for p in batch] | 681 | image_name = [p.class_image_path for p in batch] |
651 | prompt = [p.prompt.format(args.initializer_token) for p in batch] | 682 | prompt = [p.prompt.format(args.class_identifier) for p in batch] |
652 | nprompt = [p.nprompt for p in batch] | 683 | nprompt = [p.nprompt for p in batch] |
653 | 684 | ||
654 | images = pipeline( | 685 | images = pipeline( |
@@ -842,6 +873,12 @@ def main(): | |||
842 | if global_step % args.sample_frequency == 0: | 873 | if global_step % args.sample_frequency == 0: |
843 | sample_checkpoint = True | 874 | sample_checkpoint = True |
844 | 875 | ||
876 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | ||
877 | local_progress_bar.clear() | ||
878 | global_progress_bar.clear() | ||
879 | |||
880 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
881 | |||
845 | logs = { | 882 | logs = { |
846 | "train/loss": loss, | 883 | "train/loss": loss, |
847 | "lr/unet": lr_scheduler.get_last_lr()[0], | 884 | "lr/unet": lr_scheduler.get_last_lr()[0], |
@@ -903,6 +940,9 @@ def main(): | |||
903 | global_progress_bar.clear() | 940 | global_progress_bar.clear() |
904 | 941 | ||
905 | if min_val_loss > val_loss: | 942 | if min_val_loss > val_loss: |
943 | accelerator.print( | ||
944 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | ||
945 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
906 | min_val_loss = val_loss | 946 | min_val_loss = val_loss |
907 | 947 | ||
908 | if sample_checkpoint and accelerator.is_main_process: | 948 | if sample_checkpoint and accelerator.is_main_process: |
@@ -913,14 +953,15 @@ def main(): | |||
913 | # Create the pipeline using using the trained modules and save it. | 953 | # Create the pipeline using using the trained modules and save it. |
914 | if accelerator.is_main_process: | 954 | if accelerator.is_main_process: |
915 | print("Finished! Saving final checkpoint and resume state.") | 955 | print("Finished! Saving final checkpoint and resume state.") |
916 | checkpointer.checkpoint() | 956 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
917 | 957 | checkpointer.save_model() | |
918 | accelerator.end_training() | 958 | accelerator.end_training() |
919 | 959 | ||
920 | except KeyboardInterrupt: | 960 | except KeyboardInterrupt: |
921 | if accelerator.is_main_process: | 961 | if accelerator.is_main_process: |
922 | print("Interrupted, saving checkpoint and resume state...") | 962 | print("Interrupted, saving checkpoint and resume state...") |
923 | checkpointer.checkpoint() | 963 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
964 | checkpointer.save_model() | ||
924 | accelerator.end_training() | 965 | accelerator.end_training() |
925 | quit() | 966 | quit() |
926 | 967 | ||